R/functions.R

Defines functions plot_biplot plot_scree get_cb_colors score_squared_cosine score_contribution tsvd_of_C_with_names get_non_zero_coefficients_as_data_frame tsvd_of_C_variant_loadings_as_data_framte separate_ID_ALT_into_two_columns get_non_zero_coefficients get_summary_metrics read_lambda_sequence_from_multiSnpnetResults load_multiSnpnetResultsFromRDataFiles prepare_multiSnpnetResults drop_bulky_stuff_in_fit_object find_best_lambda_index extract_all_metrics extract_weighted_metrics extract_weight extract_metrics check_if_metric_exists load_multiSnpnetRdata get_non_zero_lines get_non_NA_lines find_prev_iter get_rdata_path check_early_stopping_condition weighted_rowMeans safe_product plot_multisnpnet predict_multisnpnet read_pvar coef_multisnpnet check_configs_diff compute_lambda_min_ratio timeDiff y_de_standardization y_standardization MSE_coef MSE_col which_row_active row_norm2 SRRR_iterative_missing_covariates alternate_Y_glmnet time_diff initial_Y_imputation setupMultiConfigs

Documented in check_early_stopping_condition check_if_metric_exists coef_multisnpnet drop_bulky_stuff_in_fit_object extract_all_metrics extract_metrics extract_weight extract_weighted_metrics find_best_lambda_index find_prev_iter get_cb_colors get_non_NA_lines get_non_zero_coefficients get_non_zero_coefficients_as_data_frame get_non_zero_lines get_rdata_path get_summary_metrics load_multiSnpnetRdata load_multiSnpnetResultsFromRDataFiles plot_biplot plot_multisnpnet plot_scree predict_multisnpnet prepare_multiSnpnetResults read_lambda_sequence_from_multiSnpnetResults read_pvar score_contribution score_squared_cosine separate_ID_ALT_into_two_columns tsvd_of_C_variant_loadings_as_data_framte tsvd_of_C_with_names weighted_rowMeans

# Build the configuration list used by the multi-response fitting routines.
#
# Every formal argument except `configs` is copied into `configs`, overriding
# any value already stored under the same name. Afterwards every key of
# `defaults_multi` that is still absent is filled in with its default value.
# Finally a few derived settings are computed and the intermediate-file
# directories `results.dir/meta.dir` and `results.dir/save.dir` are created.
#
# Returns the completed configuration list.
setupMultiConfigs <- function(configs, genotype_file, phenotype_file, phenotype_names, covariate_names,
                              nlambda, mem,
                              standardize_response, max.iter, rank, prev_iter, batch_size, save) {
  # Must stay the first statement: captures exactly the formal arguments.
  out.args <- as.list(environment())
  defaults_multi <- list(
    missing.rate = 0.1,
    MAF.thresh = 0.001,
    nCores = 1,
    glmnet.thresh = 1e-07,
    nlams.init = 10,
    nlams.delta = 5,
    vzs=TRUE, # geno.pfile vzs
    increase.size = NULL,
    standardize.variant = FALSE,
    early.stopping = TRUE,
    stopping.lag = 2,
    niter = 10,
    lambda.min.ratio = NULL,
    KKT.verbose = FALSE,
    use.glmnetPlus = NULL,
    save = FALSE,
    save.computeProduct = FALSE,
    prevIter = 0,
    results.dir = NULL,
    meta.dir = 'meta',
    save.dir = 'results',
    verbose = FALSE,
    KKT.check.aggressive.experimental = FALSE,
    gcount.basename.prefix = 'snpnet.train',
    gcount.full.prefix=NULL,
    endian="little",
    metric=NULL,
    plink2.path='plink2',
    zstdcat.path='zstdcat',
    rank = TRUE,
    is.warm.start = TRUE,
    is.A.converge = TRUE,
    thresh = 1e-7,
    MAXLEN = (2^31 - 1) / 2,
    use_safe = TRUE,
    excludeSNP = NULL,
    converge.type = 'obj'
  )
  # Explicit arguments take precedence over whatever `configs` already holds.
  for (name in setdiff(names(out.args), "configs")) {
    configs[[name]] <- out.args[[name]]
  }
  # Fill in defaults only for keys that are still missing.
  for (name in names(defaults_multi)) {
    if (!(name %in% names(configs))) {
      configs[[name]] <- defaults_multi[[name]]
    }
  }

  # Derived setting: default increase.size is half of batch_size.
  if (is.null(configs[['increase.size']])) {
    configs[['increase.size']] <- configs[['batch_size']] / 2
  }

  # We will write some intermediate files to meta.dir and save.dir.
  # those files will be deleted with snpnet::cleanUpIntermediateFiles() function.
  if (is.null(configs[['results.dir']])) configs[['results.dir']] <- tempdir(check = TRUE)
  dir.create(file.path(configs[['results.dir']], configs[["meta.dir"]]), showWarnings = FALSE, recursive = TRUE)
  dir.create(file.path(configs[['results.dir']], configs[["save.dir"]]), showWarnings = FALSE, recursive = TRUE)
  # `[[` extracts the prefix string itself (single `[` would pass a one-element list).
  if (is.null(configs[['gcount.full.prefix']])) {
    configs[['gcount.full.prefix']] <- file.path(
      configs[['results.dir']], configs[["meta.dir"]], configs[['gcount.basename.prefix']]
    )
  }

  configs
}


# Initial imputation of missing responses from the covariates.
#
# For each response column k, regresses the observed values on all columns
# of `covariates` (lm() drops the rows with a missing response while
# fitting). When `covariates` is NULL, the intercept-only prediction is used:
# the mean of the observed values. The entries flagged by
# `missing_response[, k]` are replaced by the prediction. The residual of the
# filled-in column against the prediction is recorded.
#
# Returns list(response = imputed response, residual = residual matrix).
initial_Y_imputation <- function(response, covariates, missing_response) {
  n <- nrow(response)
  residual <- matrix(NA, n, ncol(response))
  colnames(residual) <- colnames(response)
  for (k in seq_len(ncol(response))) {
    if (is.null(covariates)) {
      # Intercept-only model. Fitting lm() on an explicit column of ones would
      # be rank-deficient and make predict() warn.
      pred <- rep(mean(response[, k], na.rm = TRUE), n)
    } else {
      fit <- lm(response[, k] ~ ., data = covariates)
      pred <- predict(fit, newdata = covariates)
    }
    response[missing_response[, k], k] <- pred[missing_response[, k]]
    residual[, k] <- response[, k] - pred
  }
  list(response = response, residual = residual)
}


# Human-readable elapsed time between two time points, e.g. "1.5 mins".
# The value is rounded to 4 decimals and followed by the automatically
# chosen difftime unit.
time_diff <- function(start_time, end_time) {
  elapsed <- end_time - start_time
  paste(round(elapsed, 4), units(elapsed))
}


# Alternating glmnet fit / missing-response imputation.
#
# Repeatedly fits a multi-response ("mgaussian") glmnetPlus model of
# `response` on `features` at the single penalty `lambda`. Each fit is
# warm-started (beta0) from the previous coefficients. After each fit, the
# entries of `response` flagged by `missing_response` are overwritten with
# the new predictions.
#
# The objective tracked per iteration is
#   sum((response - pred)^2) / 2 / n + lambda * sum_j ||CC[j, ]||_2
# (the sum runs over all rows of CC, covariates included). Iteration stops
# once the objective decreases by less than thresh * object0 between two
# iterations, or after `max.iter` iterations. At least one iteration is
# always run.
#
# If W_init, B_init and A_init are all non-NULL, the initial coefficient
# matrix has W_init in the covariate rows and B_init %*% t(A_init) in the
# variant rows. Otherwise glmnetPlus receives beta0 = NULL.
#
# The first `num_covariates` rows of the final coefficients are returned as
# W and the remaining rows as C. If r < ncol(response), C is factorised by a
# truncated SVD: B = U_r D_r, A = V_r. Otherwise B = C and A is the r x r
# identity.
#
# Returns list(response, a0, W, C, CC, B, A, residuals, obj_values).
alternate_Y_glmnet <- function(features, response, missing_response, lambda, penalty_factor, configs,
                               num_covariates, r, thresh = 1E-7, object0, W_init, B_init, A_init, glmnet_thresh = 1e-7,
                               max.iter) {
  converge <- FALSE
  features_matrix <- as.matrix(features)
  niter <- 0
  obj_values <- c()
  cat("    Start Y-C (glmnet) iteration ...\n")
  message <- "Terminated"

  if (is.null(W_init) || is.null(B_init) || is.null(A_init)) {
    CC <- NULL
  } else {
    CC <- matrix(0, ncol(features_matrix), ncol(response))
    rownames(CC) <- colnames(features_matrix)
    CC[rownames(B_init), ] <- tcrossprod(as.matrix(B_init), A_init)
    CC[rownames(W_init), ] <- as.matrix(W_init)
  }

  repeat {
    niter <- niter + 1
    fit <- glmnetPlus::glmnet(features_matrix, response, family = "mgaussian", lambda = lambda, penalty.factor = penalty_factor,
                              standardize = configs[["standardize.variant"]], standardize.response = FALSE, beta0 = CC, thresh = glmnet_thresh)
    CC <- do.call(cbind, fit$beta)
    pred <- glmnetPlus::predict.mrelnet(fit, newx = features_matrix, type = "response")[, , 1]
    response[missing_response] <- pred[missing_response]
    obj_values[niter] <- sum((response-pred)^2) / 2 / nrow(response) + lambda * sum(apply(CC, 1, function(x) sqrt(sum(x^2))))
    cat("         Objective (", niter, "): ", round(obj_values[niter], digits = 8), "\n", sep = "")
    # Converged when the objective no longer decreases by at least
    # thresh * object0 (same criterion as the Y-A loop in
    # SRRR_iterative_missing_covariates).
    if (niter > 1 && obj_values[niter - 1] - obj_values[niter] < thresh * object0) {
      message <- "Converged"
      converge <- TRUE
    }
    # max.iter bounds the loop even when convergence is never reached.
    if (converge || niter >= max.iter) break
  }
  cat("    Finish Y-C (glmnet) iteration:", message, "in", niter, "iterations.\n")
  colnames(CC) <- colnames(response)
  W <- CC[seq_len(num_covariates), , drop = FALSE]
  # seq.int(..., length.out = ) is empty when there are no variant columns,
  # whereas (num_covariates + 1):ncol would count backwards.
  n_variants <- ncol(features_matrix) - num_covariates
  C <- CC[seq.int(num_covariates + 1, length.out = n_variants), , drop = FALSE]
  if (r < ncol(response)) {  ## but assumes that the number of selected variables is less than the target rank
    svd_obj <- svd(C)
    B <- svd_obj$u[, seq_len(r), drop = FALSE] %*% diag(svd_obj$d[seq_len(r)], r)
    rownames(B) <- rownames(C)
    A <- svd_obj$v[, seq_len(r), drop = FALSE]
    rownames(A) <- colnames(C)
  } else {
    B <- C
    A <- diag(1, nrow = r)
    rownames(A) <- colnames(C)
  }
  a0 <- fit$a0
  residuals <- response - pred
  out <- list(response = response, a0 = a0, W = W, C = C, CC = CC, B = B, A = A, residuals = residuals, obj_values = obj_values)
  out
}


# Sparse reduced-rank regression with missing responses and covariates.
#
# Each outer iteration (at most `niter`) alternates:
#   1. fix B, update A. A = U V' from the SVD of crossprod(RY, X %*% B),
#      where RY is Y with the Z-projection removed. The missing entries of Y
#      are re-imputed after each update. This inner "Y-A" loop repeats until
#      the objective decreases by less than thresh * object0, or runs once
#      if is.A.converge is FALSE.
#   2. fix A, update B with a group-lasso glmnetPlus fit ("mgaussian", no
#      intercept, no standardization) of YA = (Y - ZW) %*% A on X. The fit
#      is warm-started from the current B unless it is the first iteration
#      and is.warm.start is FALSE.
#   3. re-impute the missing entries of Y from the new fit.
# Covariate effects enter implicitly through Z %*% (PZ %*% .). This assumes
# PZ maps onto the coefficients of Z (e.g. solve(crossprod(Z), t(Z))) —
# TODO confirm against the caller.
#
# Convergence: with converge_type = "params", the relative Frobenius change
# of C = B A' must fall below `thresh`. Otherwise the absolute change of the
# objective must fall below thresh * object0.
#
# Returns list(B, A, C, a0, W, obj_values, message, niter, response,
# A_niter, residuals). a0 is the first row and W the remaining rows of the
# covariate coefficients recomputed on exit.
SRRR_iterative_missing_covariates <- function(X, Y, Y_missing, Z, PZ, lambda, r, niter, B0, thresh = 1e-7, object0,
                                              is.warm.start = FALSE, is.A.converge = TRUE, glmnet_thresh = 1e-7, converge_type = "obj") {
  n <- nrow(X)
  p <- ncol(X)

  # Column chunks of X used to form X %*% B. Each chunk has at most
  # floor((2^31 - 1) / n / 4) columns, so that X[, idx] stays far from the
  # long-vector limit ("deal with long vector"). The layout depends only on
  # the dimensions of X, so it is computed once here.
  max_len <- 2^31 - 1
  ncol_chunk <- floor(max_len / as.double(n) / 4)
  chunks <- split(seq_len(p), ceiling(seq_len(p) / ncol_chunk))

  # Dense matrix X %*% coef, accumulated chunk by chunk in column order.
  chunked_score <- function(coef) {
    total <- NULL
    for (idx in chunks) {
      part <- as.matrix(X[, idx, drop = FALSE] %*% coef[idx, , drop = FALSE])
      total <- if (is.null(total)) part else total + part
    }
    total
  }

  B <- matrix(0, p, r)
  rownames(B) <- colnames(X)
  B[rownames(B0), ] <- as.matrix(B0)

  obj_values <- rep(NA, niter)
  A_niter <- rep(NA, niter)
  message <- "Terminated"

  count <- 0

  start_BAY <- Sys.time()

  for (k in seq_len(niter)) {
    count <- count + 1
    A_niter[k] <- 0
    # fix B, solve A
    score <- chunked_score(B)

    impute_iter_count <- 0
    projected_score_Z <- Z %*% (PZ %*% score)
    RS <- score - projected_score_Z
    cat("    Start Y-A iteration ...\n")
    B_norm <- sum(row_norm2(B))
    start_Y_A <- Sys.time()
    while (TRUE) {
      A_niter[k] <- A_niter[k] + 1
      impute_iter_count <- impute_iter_count + 1
      projected_Y_Z <- Z %*% (PZ %*% Y)
      RY <- Y - projected_Y_Z
      crossmat <- crossprod(RY, score)
      svd_cross <- svd(crossmat)
      if (impute_iter_count > 1) obj_old <- obj
      A <- tcrossprod(svd_cross$u, svd_cross$v)
      Y_new <- projected_Y_Z + tcrossprod(RS, A)  # implicit W
      Y[Y_missing] <- Y_new[Y_missing]

      obj <- mean((Y - Y_new)^2)/2*ncol(Y) + lambda * B_norm
      cat("         Objective (", impute_iter_count, "): ", round(obj, digits = 8), "\n", sep = "")
      if (impute_iter_count > 1) {
        delta <- obj_old - obj
      } else {
        delta <- 100
      }
      if (delta < thresh*object0 || (!is.A.converge && A_niter[k] > 0)) {
        break
      }
    }
    end_Y_A <- Sys.time()
    cat("    Finish Y-A iteration: ", impute_iter_count, ". Time elapsed: ",
        time_diff(start_Y_A, end_Y_A), "\n", sep = "")

    cat("    Start solving for B ...\n")

    ZW <- Y_new - tcrossprod(score, A)
    YA <- (Y - ZW) %*% A
    if (k == 1 && !is.warm.start) {
      mfit <- glmnetPlus::glmnet(x = X, y = YA, family = "mgaussian", standardize = FALSE, intercept = FALSE, lambda = lambda, thresh = glmnet_thresh)
    } else {
      mfit <- glmnetPlus::glmnet(x = X, y = YA, family = "mgaussian", standardize = FALSE, intercept = FALSE, lambda = lambda, beta0 = B, thresh = glmnet_thresh)
    }
    # Normalise the single-response (r == 1) output to the multi-response layout.
    if (is.null(dim(mfit$a0))) {
      mfit$a0 <- matrix(mfit$a0, nrow = 1)
      mfit$beta <- list(mfit$beta)
    }
    beta_single <- coef(mfit, s = lambda, x = X, y = YA)
    # Drop the (zero) intercept row.
    B <- do.call(cbind, beta_single)[-1, , drop = FALSE]

    end_B <- Sys.time()
    cat("    Finish solving for B. ", "Time elapsed: ",
        time_diff(end_Y_A, end_B), "\n", sep = "")


    if (k > 1) C_old <- C
    C <- tcrossprod(as.matrix(B), A)

    score <- chunked_score(B)

    Y_new <- ZW + tcrossprod(score, A)
    Y[Y_missing] <- Y_new[Y_missing]
    residuals <- Y - Y_new
    obj_values[k] <- 1/(2*n) * sum((residuals)^2) +
      lambda * sum(row_norm2(B))

    if (converge_type == "params") {
      if (k > 1 && (sqrt(sum((C_old - C)^2)) < thresh*sqrt(sum(C^2)))) {
        message <- "Converged"
        obj_values <- obj_values[1:k]
        A_niter <- A_niter[1:k]
        break
      }
    } else {
      if (k > 1 && abs(obj_values[k] - obj_values[k-1]) < thresh*object0) {
        message <- "Converged"
        obj_values <- obj_values[1:k]
        A_niter <- A_niter[1:k]
        break
      }
    }
  }

  end_BAY <- Sys.time()
  cat("Finish B-A-Y iteration: ", message, " after ", k, " iterations. ", "Time elapsed: ",
      time_diff(start_BAY, end_BAY), "\n", sep = "")

  coef_score_Z <- PZ %*% score
  coef_Y_Z <- PZ %*% Y
  W_full <- coef_Y_Z - tcrossprod(coef_score_Z, A) ## can be different from W on exit of inner loop
  a0 <- W_full[1, ]
  W <- W_full[-1, , drop = FALSE]

  colnames(C) <- colnames(Y)

  out <- list(B = B, A = A, C = C, a0 = a0, W = W, obj_values = obj_values, message = message,
              niter = count, response = Y, A_niter = A_niter, residuals = residuals)
  out
}


# Euclidean (L2) norm of every row of X.
# apply() keeps the row names of X as the names of the result.
row_norm2 <- function(X) {
  euclidean <- function(v) sqrt(sum(v^2))
  apply(X, MARGIN = 1, FUN = euclidean)
}

# Row names of X whose rows contain at least one non-zero entry.
# Returns NULL when X has no row names.
which_row_active <- function(X) {
  has_nonzero <- apply(X, MARGIN = 1, FUN = function(row) any(row != 0))
  names(has_nonzero)[which(has_nonzero)]
}

# Per-column mean squared deviation from the column mean.
# NA entries are ignored both in the column mean and in the average.
MSE_col <- function(X) {
  centered_mse <- function(col) {
    deviation <- col - mean(col, na.rm = TRUE)
    mean(deviation^2, na.rm = TRUE)
  }
  apply(X, MARGIN = 2, FUN = centered_mse)
}

# We assume that each column of Y has been standardized

# Mean squared error between an estimated coefficient matrix and the truth.
#
# `C_hat` may hold only a subset of the rows of `C`, matched by row name
# (e.g. only the active variables). It is placed into a zero matrix shaped
# like `C`. The mean of the squared element-wise differences is returned.
MSE_coef <- function(C_hat, C) {
  C_full <- matrix(0, nrow(C), ncol(C))
  rownames(C_full) <- rownames(C)
  C_full[rownames(C_hat), ] <- as.matrix(C_hat)
  mean((C_full - as.matrix(C))^2)
}

# Standardize the named response columns and scale them by sqrt(weight).
#
# Each column `phe` becomes (x - mean) / sd * sqrt(weight[phe]), with the
# mean and sd computed over the non-missing values.
#
# Returns list(response = transformed response, means = named column means,
# sds = named column standard deviations).
y_standardization <- function(response, phenotype_names, weight) {
  centers <- c()
  scales <- c()
  for (phe in phenotype_names) {
    column <- response[[phe]]
    centers[phe] <- mean(column, na.rm = TRUE)
    scales[phe] <- sd(column, na.rm = TRUE)
    response[[phe]] <- (column - centers[phe]) / scales[phe] * sqrt(weight[phe])
  }
  list(response = response, means = centers, sds = scales)
}

# Invert the transform of y_standardization column by column.
# Each column `phe` becomes sds[phe] * x / sqrt(weight[phe]) + means[phe].
# Columns are looked up by name.
y_de_standardization <- function(response, means, sds, weight) {
  for (phe in colnames(response)) {
    standardized <- response[, phe]
    response[, phe] <- sds[phe] * standardized / sqrt(weight[phe]) + means[phe]
  }
  response
}

# Human-readable elapsed time since `start.time`, e.g. "1.5 mins".
# `end.time` defaults to the current time. The value is rounded to 4
# decimals and followed by the automatically chosen difftime unit.
timeDiff <- function(start.time, end.time = NULL) {
  finish <- if (is.null(end.time)) Sys.time() else end.time
  elapsed <- finish - start.time
  paste(round(elapsed, 4), units(elapsed))
}

# Compute the lambda.min.ratio to use when extending the original lambda
# sequence to `nlambda.new` values.
#
# This assumes the original `nlambda` lambdas are equally spaced on the log
# scale from lambda_max down to ratio * lambda_max. The result,
# ratio^((nlambda.new - 1) / (nlambda - 1)), keeps that spacing (computed as
# exp(. * log(ratio))).
compute_lambda_min_ratio <- function(nlambda.new, nlambda = 100, ratio = 0.01) {
  fraction <- (nlambda.new - 1) / (nlambda - 1)
  exp(fraction * log(ratio))
}

# Warn about differences between two configuration lists.
#
# Collects one line for each of the following:
#   - a key present in both lists whose value differs ("Changed");
#   - a key only in `old_configs` ("Deleted");
#   - a key only in `new_configs` ("Added").
# All lines are raised as a single warning. Nothing is signalled when the
# lists agree.
# Values are rendered with deparse() so that vectors, NULL and strings are
# shown unambiguously on one line.
#
# Returns the collected message invisibly ("" when there is no difference).
check_configs_diff <- function(old_configs, new_configs) {
  show_value <- function(x) paste(deparse(x), collapse = " ")
  old_names <- names(old_configs)
  new_names <- names(new_configs)
  lines <- character(0)
  for (name in intersect(old_names, new_names)) {
    if (!identical(old_configs[[name]], new_configs[[name]])) {
      lines <- c(lines, paste0("Changed config for ", name, ": ", show_value(old_configs[[name]]),
                               " -> ", show_value(new_configs[[name]])))
    }
  }
  for (name in setdiff(old_names, new_names)) {
    # A key holding NULL is indistinguishable from an absent one via `[[`.
    if (!identical(old_configs[[name]], new_configs[[name]])) {
      lines <- c(lines, paste0("Deleted config for ", name, ": ", show_value(old_configs[[name]])))
    }
  }
  for (name in setdiff(new_names, old_names)) {
    if (!identical(old_configs[[name]], new_configs[[name]])) {
      lines <- c(lines, paste0("Added config for ", name, ": ", show_value(new_configs[[name]])))
    }
  }
  if (length(lines) == 0) {
    return(invisible(""))
  }
  msg <- paste0(lines, "\n", collapse = "")
  warning(msg, call. = FALSE)
  invisible(msg)
}

#' Extract Coefficients from the Fitted Object or File
#'
#' @param fit Fit object returned from multisnpnet: a list with one element per lambda index, each
#'   holding the fields a0, W, B and A.
#' @param fit_path Path to an RDS file holding the fit object. Read with readRDS() when `fit` is
#'   NULL.
#' @param idx Lambda indices where the coefficients are requested. Defaults to every element of the
#'   fit.
#' @param uv Boolean. If TRUE the decomposed matrices are returned as U and V, otherwise as B and A.
#'
#' @return List with elements a0, W, U/V (or B/A) and idx. Each coefficient element is a list of
#'   that type of coefficient over the requested lambda indices, and idx holds the indices used.
#'
#' @export
coef_multisnpnet <- function(fit = NULL, fit_path = NULL, idx = NULL, uv = TRUE) {
  if (is.null(fit)) {
    if (is.null(fit_path)) {
      stop("Either fit object or file path to the saved object should be provided.\n")
    }
    fit <- readRDS(file = fit_path)
  }
  if (is.null(idx)) {
    idx <- seq_along(fit)
  }
  chosen <- fit[idx]
  # Same lookup semantics as `el$field` (exact match first, then unique partial match).
  collect <- function(field) lapply(chosen, function(el) el[[field, exact = FALSE]])
  pieces <- list(collect("a0"), collect("W"), collect("B"), collect("A"), idx)
  names(pieces) <- c("a0", "W", if (uv) c("U", "V") else c("B", "A"), "idx")
  pieces
}

#' Read the list of genetic variants present in the dataset.
#'
#' @param genotype_file Path prefix of the genotype files. ${genotype_file}.pvar.zst must exist.
#' @param zstdcat_path Path to the zstdcat program (a command that writes the decompressed .zst file
#'   to standard output), needed when loading variants.
#'
#' @return A data frame (data.table) containing the list of variants present in the genetic
#'   dataset, with the `#CHROM` column renamed to `CHROM`.
#'
#' @export
read_pvar <- function(genotype_file, zstdcat_path = 'zstdcat'){
  # fread(cmd = ) runs the command through a shell, so the file name is quoted.
  # The program path is left unquoted so that callers may pass e.g. "zstd -dc".
  pvar_file <- shQuote(paste0(genotype_file, '.pvar.zst'))
  pvar <- data.table::fread(cmd = paste(zstdcat_path, pvar_file))
  dplyr::rename(pvar, 'CHROM' = '#CHROM')
}

#' Predict from the Fitted Object or File
#'
#' @param fit List of fit objects returned from multisnpnet (one per lambda index). NULL entries are
#'   dropped. Shared metadata (stats, weight, phenotype names, standardization, covariates) is taken
#'   from the last remaining element.
#' @param saved_path Path prefix of the saved fit objects, used when `fit` is NULL. The full path is
#'   constructed as ${saved_path}${idx}.RData, and the object named `fit` is loaded from each file.
#' @param new_genotype_file Path to the new suite of genotype files. genotype_file.{pgen, psam,
#'   pvar.zst} must exist.
#' @param new_phenotype_file Path to the phenotype. The header must include FID, IID.
#' @param idx Lambda indices where the coefficients are requested. Required when `fit` is NULL.
#' @param covariate_names Character vector of the names of the adjustment covariates. Must be the
#'   same set as the row names of W in the fit.
#' @param split_col Name of the split column. If NULL, all samples will be used (as split "train").
#' @param split_name Vector of split labels where prediction is to be made. Splits absent from the
#'   phenotype file are dropped with a warning.
#' @param binary_phenotypes Vector of names of the binary phenotypes. Phenotypes coded 1/2 are
#'   recoded to 0/1. If training split is provided, logistic regression will be refitted on the
#'   covariates and linear prediction score (from multivariate fit) and the final prediction
#'   updated. In addition, AUC will be computed for binary phenotypes.
#' @param zstdcat_path Path to zstdcat program, needed when loading variants
#'
#' @return A list with elements `prediction` (per split, an array of samples x phenotypes x fits),
#'   `response` (per split, the response matrix for which the prediction is made), `R2` (per split,
#'   phenotypes x fits) and, when `binary_phenotypes` is given, `AUC` (per split, binary phenotypes
#'   x fits).
#'
#' @export
predict_multisnpnet <- function(fit = NULL, saved_path = NULL, new_genotype_file, new_phenotype_file,
                                idx = NULL, covariate_names = NULL, split_col = NULL, split_name = NULL,
                                binary_phenotypes = NULL, zstdcat_path = "zstdcat") {
  if (is.null(fit) && is.null(saved_path)) {
    stop("Either fit object or file path to the saved object should be provided.\n")
  }
  if (is.null(fit) && is.null(idx)) {
    stop("Lambda indices on which prediction is made must be provided.\n")
  }
  if (is.null(fit)) {
    # Load the object named `fit` from each ${saved_path}${idx}.RData file.
    fit <- vector("list", length(idx))
    for (i in seq_along(idx)) {
      e <- new.env()
      load(paste0(saved_path, idx[i], ".RData"), envir = e)
      fit[[i]] <- e$fit
    }
  }

  fit <- fit[!vapply(fit, is.null, logical(1))]
  last_fit <- fit[[length(fit)]]
  stats <- last_fit[["stats"]]

  weight <- last_fit[["weight"]]
  phenotype_names <- names(last_fit[["a0"]])
  # Full rank: B has as many columns as C (one per response). In that case the
  # joint coefficient matrix CC (covariates + variants) is used for prediction.
  is_full_rank <- (ncol(as.matrix(last_fit[["B"]])) == ncol(as.matrix(last_fit[["C"]])))

  if ("std_obj" %in% names(last_fit)) {
    std_obj <- last_fit[["std_obj"]]
  } else {
    # No standardization recorded: identity transform (means 0, sds 1).
    means_std <- rep(0, length(phenotype_names))
    names(means_std) <- phenotype_names
    sds_std <- rep(1, length(phenotype_names))
    names(sds_std) <- phenotype_names
    std_obj <- list(means = means_std, sds = sds_std)
  }

  covariate_names_fit <- rownames(last_fit[["W"]])
  if (!setequal(covariate_names_fit, covariate_names)) {
    stop("Unequal covariate sets in the fit and the argument.\n",
         "Fit: ", toString(covariate_names_fit), "\n",
         "Argument: ", toString(covariate_names), "\n")
  }

  ids <- list()
  ids[["psam"]] <- snpnet:::readIDsFromPsam(paste0(new_genotype_file, '.psam'))

  configs <- list(zstdcat.path = zstdcat_path)
  phe_master <- snpnet::readPheMaster(new_phenotype_file, ids[['psam']], NULL, covariate_names, phenotype_names, NULL, split_col, configs)

  if (is.null(split_col)) {
    split_name <- "train"
    ids[["train"]] <- phe_master$ID
  } else {
    for (split in split_name) {
      ids[[split]] <- phe_master$ID[phe_master[[split_col]] == split]
      if (length(ids[[split]]) == 0) {
        warning(paste("Split", split, "doesn't exist in the phenotype file. Excluded from prediction.\n"))
        split_name <- setdiff(split_name, split)
      }
    }
  }

  # "train" must come first so that the logistic refit below is available for the other splits.
  if ("train" %in% split_name) {
    split_name <- c("train", setdiff(split_name, "train"))
  }

  for (split in split_name) {
    ids[[split]] <- ids[[split]][!is.na(ids[[split]])]
  }

  phe <- list()
  for (split in split_name) {
    ids_loc <- match(ids[[split]], phe_master[["ID"]])
    phe[[split]] <- phe_master[ids_loc]
  }

  covariates <- list()

  for (split in split_name) {
    if (length(covariate_names) > 0) {
      covariates[[split]] <- phe[[split]][, covariate_names, with = FALSE]
    } else {
      covariates[[split]] <- NULL
    }
  }

  vars <- dplyr::mutate(read_pvar(new_genotype_file, zstdcat_path), VAR_ID=paste(ID, ALT, sep='_'))$VAR_ID
  pvar <- pgenlibr::NewPvar(paste0(new_genotype_file, '.pvar.zst'))
  chr <- list()
  for (split in split_name) {
    chr[[split]] <- pgenlibr::NewPgen(paste0(new_genotype_file, '.pgen'), pvar = pvar, sample_subset = match(ids[[split]], ids[["psam"]]))
  }
  pgenlibr::ClosePvar(pvar)

  # Union of the variants active in any of the fits.
  feature_names <- c()
  for (i in seq_along(fit)) {
    feature_names <- c(feature_names, which_row_active(fit[[i]]$B))
  }
  feature_names <- unique(feature_names)

  features <- list()
  for (split in split_name) {
    if (!is.null(covariates[[split]]) && is_full_rank) {
      features[[split]] <- data.table::data.table(covariates[[split]])
      features[[split]][, (feature_names) := snpnet:::prepareFeatures(chr[[split]], vars, feature_names, stats)]
    } else {
      features[[split]] <- snpnet:::prepareFeatures(chr[[split]], vars, feature_names, stats)
    }
  }

  pred <- list()
  R2 <- list()
  if (length(binary_phenotypes) > 0) AUC <- list()
  for (split in split_name) {
    pred[[split]] <- array(dim = c(nrow(features[[split]]), length(last_fit[["a0"]]), length(fit)),
                           dimnames = list(ids[[split]], phenotype_names, seq_along(fit)))
    R2[[split]] <- array(dim = c(length(last_fit[["a0"]]), length(fit)),
                         dimnames = list(phenotype_names, seq_along(fit)))
    if (length(binary_phenotypes) > 0) {
      AUC[[split]] <- array(dim = c(length(binary_phenotypes), length(fit)),
                            dimnames = list(binary_phenotypes, seq_along(fit)))
    }
  }

  response <- list()
  variance <- list()
  for (split in split_name) {
    response[[split]] <- as.matrix(phe[[split]][, phenotype_names, with = FALSE])
    for (bphe in binary_phenotypes) {
      # Recode a 1/2-coded binary phenotype to 0/1. Only this phenotype's
      # column is touched, so columns already coded 0/1 are left alone.
      if (all(response[[split]][, bphe] >= 1, na.rm = TRUE) && all(response[[split]][, bphe] <= 2, na.rm = TRUE)) {
        response[[split]][, bphe] <- response[[split]][, bphe] - 1
      }
    }
    variance[[split]] <- apply(response[[split]], 2, function(x) mean((x - mean(x, na.rm = TRUE))^2, na.rm = TRUE))
  }

  for (i in seq_along(fit)) {
    if (length(binary_phenotypes) > 0 && "train" %in% split_name) glmfit_bin <- list()
    for (split in split_name) {
      if (!is.null(covariates[[split]]) && !is_full_rank) {
        # Reduced rank: variant part via C, covariates via W (when W covers them), plus a0.
        active_vars <- which_row_active(fit[[i]]$C)
        if (length(active_vars) > 0) {
          features_single <- as.matrix(features[[split]][, active_vars, with = FALSE])
        } else {
          features_single <- matrix(0, nrow = nrow(features[[split]]), ncol = 0)
        }
        pred_var <- safe_product(features_single, fit[[i]]$C[active_vars, , drop = FALSE])
        if (is.null(fit[[i]]$W) || nrow(fit[[i]]$W) < ncol(covariates[[split]])) {
          pred_single <- sweep(pred_var, 2, fit[[i]]$a0, FUN = "+")
        } else {
          pred_single <- as.matrix(covariates[[split]]) %*% fit[[i]]$W + sweep(pred_var, 2, fit[[i]]$a0, FUN = "+")
        }
      } else {
        # Full rank (or no covariates): the joint coefficient matrix CC plus a0.
        active_vars <- which_row_active(fit[[i]]$CC)
        if (length(active_vars) > 0) {
          features_single <- as.matrix(features[[split]][, active_vars, with = FALSE])
        } else {
          # Empty design when no row is active (same as the reduced-rank branch).
          features_single <- matrix(0, nrow = nrow(features[[split]]), ncol = 0)
        }
        pred_var <- safe_product(features_single, fit[[i]]$CC[active_vars, , drop = FALSE])
        pred_single <- sweep(pred_var, 2, fit[[i]]$a0, FUN = "+")
      }
      pred_single <- as.matrix(pred_single)
      colnames(pred_single) <- colnames(fit[[i]]$C)
      pred_single <- as.matrix(y_de_standardization(pred_single, std_obj$means, std_obj$sds, weight))
      R2[[split]][, i] <- 1 - apply((pred_single - response[[split]])^2, 2, mean, na.rm = TRUE) / variance[[split]]
      if (length(binary_phenotypes) > 0) {
        for (bphe in binary_phenotypes) {
          if ("train" %in% split_name) {
            # Refit logistic regression on covariates + score using the training split,
            # then replace the score by the predicted probability.
            if (split == "train") {
              data_logistic_train <- data.frame(response = response[[split]][, bphe], covariates[[split]], score = pred_single[, bphe])
              glmfit_bin[[bphe]] <- glm(response ~ ., data = data_logistic_train, family = binomial())
            }
            data_logistic_split <- data.frame(covariates[[split]], score = pred_single[, bphe])
            pred_prob_split <- predict(glmfit_bin[[bphe]], newdata = data_logistic_split, type = "response")
            pred_single[, bphe] <- pred_prob_split
          }
          not_missing <- !is.na(response[[split]][, bphe])
          pred_obj <- ROCR::prediction(pred_single[not_missing, bphe], response[[split]][not_missing, bphe])
          auc_obj <- ROCR::performance(pred_obj, measure = 'auc')
          AUC[[split]][bphe, i] <- auc_obj@y.values[[1]]
        }
      }
      pred[[split]][, , i] <- pred_single
    }
  }

  out <- list(prediction = pred, response = response, R2 = R2)
  if (length(binary_phenotypes) > 0) out[["AUC"]] <- AUC

  out
}


#' Make plots of the multisnpnet results
#'
#' For example, for the 50th lambda the reduced-rank results are saved at
#' ${results_dir}[i]/${rank_prefix}[i]${rank}[j]/${file_prefix}[i]50${file_suffix}[i], and the snpnet
#' results are saved at ${snpnet_dir}/${phenotype}/${snpnet_subdir}/${snpnet_prefix}50${snpnet_suffix}.
#'
#' @param results_dir Character vector each specifies the parent directory of one type (e.g. exact,
#'   lazy) of results
#' @param rank_prefix Character vector each specifies the prefix of the subdirectories holding
#'   per-rank results separately
#' @param type Character vector each specifies the type of the results, e.g., exact, lazy
#' @param rank Numeric vector of ranks for which the results are available
#' @param file_prefix Character vector each specifies the prefix of the result file
#' @param file_suffix Character vector each specifies the suffix of the result file
#' @param snpnet_dir Parent directory of the snpnet results
#' @param snpnet_subdir Name of the subdirectory hosting snpnet results. The result files are saved
#'   in ${snpnet_dir}/${phenotype}/${snpnet_subdir}
#' @param save_dir Directory to save the plots. If NULL, no plots are generated but the list of plot
#'   objects is returned
#' @param train_name Name of the object storing the training metric
#' @param val_name Name of the object storing the validation metric
#' @param xlim The x limits (x1, x2) of the plots
#' @param ylim The y limits (y1, y2) of the plots
#'
#' @import ggplot2
#'
#' @export
plot_multisnpnet <- function(results_dir, rank_prefix, type, rank,
                             file_prefix, file_suffix,
                             snpnet_dir = NULL, snpnet_subdir = NULL, snpnet_prefix = NULL, snpnet_suffix = NULL,
                             save_dir = NULL, train_name = "metric_train", val_name = "metric_val", test_name = NULL, metric_name = "R2",
                             train_bin_name = "AUC_train", val_bin_name = "AUC_val", test_bin_name = "AUC_test", metric_bin_name = "AUC",
                             xlim = c(NA, NA), ylim = c(NA, NA), mapping_phenotype = NULL) {
  if (!is.null(save_dir)) dir.create(save_dir, recursive = T)
  data_metric_full <- NULL
  bin_names <- c()
  for (dir_idx in seq_along(results_dir)) {
    for (r in rank) {
      dir_rank <- file.path(results_dir[dir_idx], paste0(rank_prefix[dir_idx], r))
      files_in_dir <- list.files(dir_rank)
      result_files <- files_in_dir[startsWith(files_in_dir, file_prefix[dir_idx])]
      max_iter <- max(as.numeric(gsub(file_suffix[dir_idx], "", gsub(pattern = file_prefix[dir_idx], "", result_files))))
      latest_result <- file.path(dir_rank, paste0(file_prefix[dir_idx], max_iter, file_suffix[dir_idx]))

      myenv <- new.env()
      load(latest_result, envir = myenv)
      metric_train <- myenv[[train_name]]
      metric_val <- myenv[[val_name]]
      if (!is.null(test_name)) {
        if (!(test_name %in% names(myenv))) {
          stop("Test result doesn't exist for multisnpnet rank ", r, ".\n")
        }
        metric_test <- myenv[[test_name]]
      }
      if ((train_bin_name %in% names(myenv)) && (val_bin_name %in% names(myenv))) {
        AUC_train <- myenv[[train_bin_name]]
        AUC_val <- myenv[[val_bin_name]]
        if (!is.null(test_name)) AUC_test <- myenv[[test_bin_name]]
        bin_names <- unique(c(bin_names, colnames(AUC_train)))
        metric_train[, colnames(AUC_train)] <- AUC_train
        metric_val[, colnames(AUC_val)] <- AUC_val
        if (!is.null(test_name)) metric_test[, colnames(AUC_test)] <- AUC_test
      }
      imax_train <- max(which(apply(metric_train, 1, function(x) sum(is.na(x))) == 0))
      imax_val <- max(which(apply(metric_val, 1, function(x) sum(is.na(x))) == 0))
      imax <- min(imax_train, imax_val)
      if (!is.null(test_name)) {
        imax_test <- max(which(apply(metric_test, 1, function(x) sum(is.na(x))) == 0))
        imax <- min(imax, imax_test)
        metric_test <- metric_test[1:imax, , drop = F]
        metric_test <- cbind(metric_test, lambda = 1:imax)
        table_test <- reshape2::melt(as.data.frame(metric_test), id.vars = "lambda", variable.name = "phenotype", value.name = "metric_test")
      }
      metric_train <- metric_train[1:imax, , drop = F]
      metric_val <- metric_val[1:imax, , drop = F]
      metric_train <- cbind(metric_train, lambda = 1:imax)
      metric_val <- cbind(metric_val, lambda = 1:imax)

      table_train <- reshape2::melt(as.data.frame(metric_train), id.vars = "lambda", variable.name = "phenotype", value.name = "metric_train")
      table_val <- reshape2::melt(as.data.frame(metric_val), id.vars = "lambda", variable.name = "phenotype", value.name = "metric_val")
      data_metric <- dplyr::inner_join(table_train, table_val, by = c("phenotype", "lambda"))
      if (!is.null(test_name)) {
        data_metric <- dplyr::inner_join(data_metric, table_test, by = c("phenotype", "lambda"))
      }
      data_metric[["type"]] <- type[dir_idx]
      data_metric[["rank"]] <- factor(r, levels = as.character(rank))

      data_metric_full <- rbind(data_metric_full, data_metric)
    }
  }

  if (!is.null(snpnet_dir)) {
    for (phe in as.character(unique(data_metric_full[["phenotype"]]))) {
      print(phe)
      phe_dir <- file.path(snpnet_dir, phe, snpnet_subdir)  # results/results
      files_in_dir <- list.files(phe_dir)
      result_files <- files_in_dir[startsWith(files_in_dir, snpnet_prefix) & endsWith(files_in_dir, snpnet_suffix)]
      max_iter <- max(as.numeric(gsub(snpnet_suffix, "", gsub(pattern = snpnet_prefix, "", result_files))))
      latest_result <- file.path(phe_dir, paste0(snpnet_prefix, max_iter, snpnet_suffix))

      myenv <- new.env()
      load(latest_result, envir = myenv)
      metric_train <- myenv[["metric.train"]]
      metric_val <- myenv[["metric.val"]]
      if (!is.null(test_name)) {
        if (!("metric.test" %in% names(myenv))) {
          stop("Test result doesn't exist for ", phe, ".\n")
        }
        metric_test <- myenv[["metric.test"]]
        imax_test <- max(which(!is.na(metric_test)))
      }

      imax_train <- max(which(!is.na(metric_train)))
      imax_val <- max(which(!is.na(metric_val)))
      imax <- min(imax_train, imax_val)
      if (!is.null(test_name)) imax <- min(imax, imax_test)

      table_snpnet <- data.frame(lambda = 1:imax, phenotype = rep(phe, imax), metric_train = metric_train[1:imax],
                                 metric_val = metric_val[1:imax], type = "exact", rank = "snpnet")
      if (!is.null(test_name)) table_snpnet[["metric_test"]] <- metric_test[1:imax]

      data_metric_full <- rbind(data_metric_full, table_snpnet)
    }
  }

  if (!is.null(mapping_phenotype)) {
    data_metric_full$phenotype <- as.character(data_metric_full$phenotype)
    for (phe in names(mapping_phenotype)) {
      data_metric_full$phenotype[data_metric_full$phenotype == phe] <- mapping_phenotype[phe]
    }
    reverse_mapping <- names(mapping_phenotype)
    names(reverse_mapping) <- mapping_phenotype
  }

  gp <- list(data = data_metric_full)

  if (!is.null(snpnet_dir)) {
    max_metric_reduced_rank <- data_metric_full %>%
      dplyr::filter(rank != "snpnet") %>%
      dplyr::group_by(phenotype) %>%
      dplyr::filter(metric_val == max(metric_val)) %>%
      dplyr::mutate(max_multisnpnet_val = ifelse(!is.null(test_name), metric_test, metric_val)) %>%
      dplyr::select(phenotype, max_multisnpnet_val, rank, lambda)
    max_metric_snpnet <- data_metric_full %>%
      dplyr::filter(rank == "snpnet") %>%
      dplyr::group_by(phenotype) %>%
      dplyr::filter(metric_val == max(metric_val)) %>%
      dplyr::mutate(max_snpnet_val = ifelse(!is.null(test_name), metric_test, metric_val)) %>%
      dplyr::select(phenotype, max_snpnet_val)
    max_metric <- max_metric_reduced_rank %>%
      dplyr::inner_join(max_metric_snpnet, by = "phenotype") %>%
      dplyr::mutate(absolute_change = max_multisnpnet_val - max_snpnet_val,
                    relative_change = max_multisnpnet_val/abs(max_snpnet_val)-1,
                    direction = ifelse(relative_change > 0, "P", "N"))
    val_test_label <- ifelse(!is.null(test_name), "Test ", NULL)
    gp[["max_metric"]] <- max_metric
    gp[["metric_cmp_abs_change"]] <- ggplot(max_metric, aes(x = phenotype, y = absolute_change)) +
      geom_bar(stat = "identity", position = "dodge", aes(fill = direction)) +
      geom_hline(yintercept = 0, colour = "grey90") +
      theme(axis.text.x = element_text(angle = 45, hjust = 1), legend.position = "none") +
      xlab("Phenotype") + ylab(paste0(val_test_label, "Metric Absolute Change"))
    if (!is.null(save_dir)) {
      save_path <- file.path(save_dir, "metric_cmp_abs_change.pdf")
      ggsave(save_path, plot = gp[["metric_cmp_abs_change"]])
    }
    gp[["metric_cmp_rel_change"]] <- ggplot(max_metric, aes(x = phenotype, y = relative_change*100)) +
      geom_bar(stat = "identity", position = "dodge", aes(fill = direction)) +
      geom_hline(yintercept = 0, colour = "grey90") +
      theme(axis.text.x = element_text(angle = 45, hjust = 1), legend.position = "none") +
      xlab("Phenotype") + ylab(paste0(val_test_label, "Metric Relative Change (%)"))
    if (!is.null(save_dir)) {
      save_path <- file.path(save_dir, "metric_cmp_rel_change.pdf")
      ggsave(save_path, plot = gp[["metric_cmp_rel_change"]])
    }
    # relative plot with absoluate value on the second y axis
    abs_range <- range(max_metric$absolute_change)
    rel_range <- range(max_metric$relative_change)
    if (abs_range[1] * abs_range[2] < 0) {
      multiplier <- min(rel_range[2] / abs_range[2], rel_range[1] / abs_range[1]) * 100
    } else {
      multiplier <- max(abs(rel_range)) / max(abs(abs_range)) * 100 / 2
    }
    gp[["metric_cmp_abs_rel_change"]] <- ggplot(max_metric, aes(x = reorder(phenotype, -relative_change), y = relative_change*100)) +
      geom_bar(stat = "identity", position = "dodge", aes(fill = direction)) +
      geom_hline(yintercept = 0, colour = "grey90") +
      geom_point(aes(y = absolute_change * multiplier), size = 1.5) +
      theme(axis.text.x = element_text(angle = 45, hjust = 1), legend.position = "none") +
      scale_y_continuous(sec.axis = sec_axis(~. * (1.0/multiplier), name = paste0(val_test_label, "Metric Absolute Change"))) +
      xlab("Phenotype") + ylab(paste0(val_test_label, "Metric Relative Change (%)"))
    if (!is.null(save_dir)) {
      save_path <- file.path(save_dir, "metric_cmp_abs_rel_change.pdf")
      ggsave(save_path, plot = gp[["metric_cmp_abs_rel_change"]])
    }
  }

  for (phe in as.character(unique(data_metric_full[["phenotype"]]))) {
    fname_phe <- ifelse(!is.null(mapping_phenotype) && (phe %in% mapping_phenotype), reverse_mapping[phe], phe)
    mname <- ifelse(fname_phe %in% bin_names, metric_bin_name, metric_name)
    gp[[fname_phe]] <- ggplot(dplyr::filter(data_metric_full, phenotype == phe), aes(x = metric_train, y = metric_val, shape = type, colour = rank)) +
      geom_path() + geom_point() +
      xlab(paste(mname, "(train)")) + ylab(paste(mname, "(val)")) +
      xlim(as.numeric(xlim)) + ylim(as.numeric(ylim)) +
      theme(axis.text=element_text(size=12), axis.title=element_text(size=12),
            legend.text=element_text(size=12), legend.title = element_text(size=12),
            legend.position = "bottom",
            strip.text.x = element_text(size = 12), strip.text.y = element_text(size = 12)) +
      ggtitle(phe)
    if (!is.null(save_dir)) {
      save_path <- file.path(save_dir, paste0(mname, "_plot_", fname_phe, ".pdf"))
      ggsave(save_path, plot = gp[[fname_phe]])
    }
  }

  gp
}

# Matrix product X %*% Y computed in column chunks of X.
#
# The inner dimension (columns of X / rows of Y) is split into chunks so that
# each partial product touches roughly at most MAXLEN entries of X; the partial
# products are summed. With use_safe = FALSE, X %*% Y is computed in one call.
#
# Returns a nrow(X) x ncol(Y) matrix; the chunked path sets rownames(X) and
# colnames(Y) on the result.
safe_product <- function(X, Y, MAXLEN = (2^31 - 1) / 2, use_safe = TRUE) {
  if (!use_safe) {
    return(X %*% Y)
  }
  # Use at least one column per chunk: when MAXLEN < nrow(X) the floor() is 0,
  # which would make the chunk count infinite and seq_len() fail.
  ncol.chunk <- max(1, floor(MAXLEN / as.double(nrow(X))))
  numChunks <- ceiling(ncol(X) / as.double(ncol.chunk))
  out <- matrix(0, nrow(X), ncol(Y))
  rownames(out) <- rownames(X)
  colnames(out) <- colnames(Y)
  for (jc in seq_len(numChunks)) {
    idx <- ((jc - 1) * ncol.chunk + 1):min(jc * ncol.chunk, ncol(X))
    out <- out + X[, idx, drop = FALSE] %*% Y[idx, , drop = FALSE]
  }
  out
}

#' weighted row means
#'
#' For every row of \code{M}, computes the weighted average of the columns named
#' in \code{w}. The weights are normalized to sum to one before use.
#'
#' @param M matrix (or data frame) whose column names include \code{names(w)}
#' @param w weights as a named numeric vector; the names select the columns of \code{M}
#'
#' @return A one-column matrix holding one weighted average per row of \code{M}.
#'
weighted_rowMeans <- function(M, w){
  w <- w / sum(w)
  # drop = FALSE keeps a single-row M as a 1 x k matrix; without it the row
  # collapses to a vector, as.matrix() turns it into a k x 1 column, and the
  # product with the k x 1 weight matrix is non-conformable.
  as.matrix(M[, names(w), drop = FALSE]) %*% as.matrix(w)
}


#' Check if the early stopping condition is satisfied
#'
#' For the specified traits, we check whether the validation set metric (and, when supplied, the validation set AUC) of every trait attained its maximum at least \code{stopping.lag} lambda steps before \code{ilam}. If \code{check_average} is TRUE, the same check is also applied to the weighted average of the validation set metric (and of the validation set AUC) across traits. If no traits were provided, we check it for all the columns of \code{metric_val}.
#'
#' @param ilam Lambda index
#' @param metric_val Validation set metric (rows: lambda index, columns: traits)
#' @param AUC_val Validation set AUC (optional). Per-trait checks use only the traits that are also columns of \code{AUC_val}.
#' @param traits (optional) subset of traits
#' @param weight weights (named numeric vector) for the weighted average computation. Defaults to uniform weights over the columns of \code{metric_val}.
#' @param check_average whether to check the average of the metrics
#' @param stopping.lag how many iterations shall we wait after the argmax until we stop. Negative values are treated as 0.
#'
#' @return TRUE if the early stopping condition is satisfied, FALSE otherwise.
#'
#' @export
check_early_stopping_condition <- function(ilam, metric_val, AUC_val = NULL, traits = NULL, weight = NULL, check_average = TRUE, stopping.lag = 2){
  if(stopping.lag < 0) stopping.lag <- 0
  if(ilam <= stopping.lag){
    return(FALSE)
  }
  if(is.null(traits)){
    traits <- colnames(metric_val)
  }
  if(is.null(weight)){
    weight <- setNames(rep(1, ncol(metric_val)), colnames(metric_val))
  }
  check_AUC <- ((! is.null(AUC_val)) && ncol(AUC_val) > 0)

  # An argmax is "stale" when it lies at least stopping.lag steps before ilam.
  # A missing argmax (which.max() of an all-NA vector) never triggers stopping.
  is_stale <- function(argmax_idx){
    length(argmax_idx) == 1L && argmax_idx + stopping.lag <= ilam
  }
  # Latest per-trait argmax. drop = FALSE keeps a single-row input as a
  # (1 x traits) matrix so which.max() runs along the lambda axis.
  latest_trait_argmax <- function(mat, cols){
    max(apply(as.matrix(mat[, cols, drop = FALSE]), 2, which.max))
  }

  stale <- is_stale(latest_trait_argmax(metric_val, traits))
  if(check_average){
    stale <- stale && is_stale(which.max(weighted_rowMeans(metric_val, weight)))
  }
  if(check_AUC){
    # AUC may only be available for a subset of the traits (e.g. binary ones)
    AUC_traits <- intersect(traits, colnames(AUC_val))
    if(length(AUC_traits) > 0){
      stale <- stale && is_stale(latest_trait_argmax(AUC_val, AUC_traits))
    }
    if(check_average){
      stale <- stale && is_stale(which.max(weighted_rowMeans(AUC_val, weight[colnames(AUC_val)])))
    }
  }
  stale
}


#' Get the path of the R data file
#'
#' The file is named \code{output_lambda_<idx>.RData} and lives in \code{results_dir}.
#'
#' @param results_dir The results directory.
#' @param idx Lambda index (a vector of indices gives one path per index)
#'
#' @return A character vector of file paths.
#'
#' @export
get_rdata_path <- function(results_dir, idx){
  rdata_file_name <- paste0("output_lambda_", idx, ".RData")
  file.path(results_dir, rdata_file_name)
}


#' Get the index of the previous iteration in a specified results directory
#'
#' Returns the largest lambda index in \code{1..nlambda} whose RData file
#' (as named by get_rdata_path()) exists in \code{results_dir}, or 0 if none does.
#'
#' @param results_dir The results directory.
#' @param nlambda The maximum number of lambda
#'
#' @return The largest lambda index with a saved RData file, or 0.
#'
#' @export
find_prev_iter <- function(results_dir, nlambda = 100){
  # seq_len() is empty for nlambda = 0, whereas 1:0 would probe indices 1 and 0
  candidate_idx <- seq_len(nlambda)
  existing_idx <- candidate_idx[file.exists(get_rdata_path(results_dir, candidate_idx))]
  if(length(existing_idx) == 0){
    return(0)
  }
  max(existing_idx)
}


#' Given a matrix, return non-NA lines
#'
#' Keeps the rows of \code{M} that contain no NA entries. The result is always
#' two-dimensional, even when exactly one row is kept.
#'
#' @param M matrix
#'
#' @return The rows of \code{M} without any NA, as a matrix (or data frame if \code{M} is one).
#'
get_non_NA_lines <- function(M){
  # drop = FALSE: a single kept row must stay a matrix so that nrow() works
  M[apply(M, 1, function(x){! anyNA(x)}), , drop = FALSE]
}


#' Given a matrix, return lines with non-zero entries
#'
#' Keeps the rows of \code{M} that have at least one non-zero entry. The result
#' is always two-dimensional (row names preserved), even when exactly one row is kept.
#'
#' @param M matrix
#'
#' @return The rows of \code{M} that are not entirely zero.
#'
get_non_zero_lines <- function(M){
  # drop = FALSE: a single kept row must stay a 1 x k matrix; otherwise it
  # becomes a vector and loses its row name (the variant identifier)
  M[apply(M, 1, function(x){! all(x == 0)}), , drop = FALSE]
}


#' Given a results directory and lambda index (optional), load the corresponding R Data file
#'
#' @param results_dir The results directory
#' @param lambda_idx The lambda index. If not specified, we call find_prev_iter() and load the last lambda index available on the file system.
#'
#' @return A new environment containing the objects stored in the RData file.
#'   An error is raised when the RData file does not exist (for example when no
#'   results were found and find_prev_iter() returned 0).
#'
#' @export
#'
load_multiSnpnetRdata <- function(results_dir, lambda_idx = NULL){
  if(is.null(lambda_idx)){
    lambda_idx <- find_prev_iter(results_dir, nlambda = 200)
  }
  # load the data at the specified lambda idx
  rdata_path <- get_rdata_path(results_dir, lambda_idx)
  e_lambda_idx <- new.env()
  message(sprintf('loading lambda_idx = %d in %s', lambda_idx, results_dir))
  if(!file.exists(rdata_path)){
    stop(sprintf("RData file for lambda_idx = %s not found: %s", lambda_idx, rdata_path), call. = FALSE)
  }
  load(rdata_path, envir = e_lambda_idx)
  return(e_lambda_idx)
}


#' Check if the training (or validation) set metric exists in the multiSnpnetResults object (list)
#'
#' A metric is considered present when an element with that name exists and
#' holds numeric values (e.g. an all-NA logical placeholder does not count).
#'
#' @param multiSnpnetResults a list (or environment) containing the results of the multiSnpnet fit
#' @param metric_name the name of the metric element to look for
#'
#' @return A boolean value
#' @examples
#' check_if_metric_exists(fit, 'metric_train')
#' check_if_metric_exists(fit, 'metric_val')
#' check_if_metric_exists(fit, 'AUC_train')
#' check_if_metric_exists(fit, 'AUC_val')
#'
#' @export
#'
check_if_metric_exists <- function(multiSnpnetResults, metric_name){
  if(!(metric_name %in% names(multiSnpnetResults))){
    return(FALSE)
  }
  is.numeric(multiSnpnetResults[[metric_name]])
}


#' Extract metrics
#'
#' Returns one metric matrix from the results. When \code{metric_name} is not
#' given, the first available one among AUC_val, metric_val, and AUC_train
#' (tested with check_if_metric_exists()) is used, falling back to metric_train;
#' the chosen name is reported with message().
#'
#' @param multiSnpnetResults a list containing the results of the multiSnpnet fit
#' @param metric_name The types of metrics (optional)
#'
#' @return A matrix of metrics (one of the followings: AUC_val, metric_val, AUC_train, and metric_train)
#'
extract_metrics <- function(multiSnpnetResults, metric_name = NULL){
  if(!is.null(metric_name)){
    return(multiSnpnetResults[[metric_name]])
  }
  # candidates in order of preference; metric_train is the unconditional fallback
  metric_name <- Find(
    function(candidate){ check_if_metric_exists(multiSnpnetResults, candidate) },
    c('AUC_val', 'metric_val', 'AUC_train'),
    nomatch = 'metric_train'
  )
  message(sprintf('metric: %s', metric_name))
  multiSnpnetResults[[metric_name]]
}


#' Extract weights
#'
#' Looks for trait weights first in a top-level \code{weight} element, then in
#' \code{configs$weight}.
#'
#' @param multiSnpnetResults a list containing the results of the multiSnpnet fit
#'
#' @return The stored trait weights, or NULL when none are stored.
#'
extract_weight <- function(multiSnpnetResults){
  # a top-level weight takes precedence over the one stored in configs
  if('weight' %in% names(multiSnpnetResults)){
    return(multiSnpnetResults[['weight']])
  }
  has_config_weight <- ('configs' %in% names(multiSnpnetResults)) &&
    ('weight' %in% names(multiSnpnetResults[['configs']]))
  if(has_config_weight) multiSnpnetResults[['configs']][['weight']] else NULL
}


#' Extract weighted metrics
#'
#' Multiplies each column (trait) of the metric matrix by its trait weight, after
#' normalizing the weights to sum to one over those columns. The row sums of the
#' result are therefore weighted averages across traits.
#'
#' @param multiSnpnetResults a list containing the results of the multiSnpnet fit
#' @param metric_name The types of metrics (optional)
#' @param weight The trait weights (optional): a numeric vector named by trait.
#'   If NULL, the weights stored in \code{multiSnpnetResults} (see extract_weight())
#'   are used, falling back to uniform weights. An unnamed vector whose length
#'   equals the number of columns is matched to the columns by position.
#'
#' @return A matrix of metrics (one of the followings: AUC_val, metric_val, AUC_train, and metric_train)
#'
extract_weighted_metrics <- function(multiSnpnetResults, metric_name = NULL, weight = NULL){
  metric_mat <- extract_metrics(multiSnpnetResults, metric_name)
  if(is.null(weight)){
    # get trait weights
    weight <- extract_weight(multiSnpnetResults)
    if(is.null(weight)){
      message("Using uniform weights")
      weight <- rep(1, ncol(metric_mat))
    }
  }
  trait_names <- colnames(metric_mat)
  if(!is.null(trait_names)){
    # unnamed weights (including the uniform fallback) are in column order;
    # name them so the lookup below does not turn them all into NA
    if(is.null(names(weight)) && length(weight) == length(trait_names)){
      names(weight) <- trait_names
    }
    weight <- weight[trait_names]
  }
  weight <- weight / sum(weight)
  # scale every column (not just the first) by its normalized weight
  for(col_idx in seq_len(ncol(metric_mat))){
    metric_mat[, col_idx] <- metric_mat[, col_idx] * weight[col_idx]
  }
  return(metric_mat)
}


#' Extract all weighted and unweighted metrics
#'
#' For every metric present in the results (metric_train, metric_val, AUC_train,
#' AUC_val), stacks the raw and the trait-weighted metric matrices and reshapes
#' them into long format.
#'
#' @param multiSnpnetResults a list containing the results of the multiSnpnet fit
#'
#' @return A dataframe of all available metrics, with columns lambda_idx,
#'   metric_name, weighted, trait, and metric
#'
#' @export
extract_all_metrics <- function(multiSnpnetResults){
  available_metrics <- Filter(
    function(mn){ check_if_metric_exists(multiSnpnetResults, mn) },
    c('metric_train', 'metric_val', 'AUC_train', 'AUC_val')
  )
  # wide data frame for one metric matrix, tagged with its name and weighting
  tag_metric_table <- function(metric_mat, metric_n, is_weighted){
    metric_mat %>%
      data.frame() %>%
      mutate(metric_name = metric_n, weighted = is_weighted) %>%
      rownames_to_column('lambda_idx')
  }
  per_metric_tables <- lapply(available_metrics, function(metric_n){
    bind_rows(
      tag_metric_table(extract_metrics(multiSnpnetResults, metric_n), metric_n, FALSE),
      tag_metric_table(extract_weighted_metrics(multiSnpnetResults, metric_n), metric_n, TRUE)
    )
  })
  per_metric_tables %>%
    bind_rows() %>%
    mutate(lambda_idx = as.integer(lambda_idx)) %>%
    gather(trait, metric, -lambda_idx, -metric_name, -weighted)
}


#' Select the "best" lambda index given the training (and validation) set metrics
#'
#' If the "lambda_idx" attribute in the given object is already available, simply return that value. Otherwise, we look at the performance metric and decide the "best" lambda index.
#' If there is no validation set metric available, we simply return the last lambda index in the training set metric.
#' When we have access to validation set metrics, we identify the lambda index that maximizes the (weighted) average of validation set metrics.
#'
#' @param multiSnpnetResults a list containing the results of the multiSnpnet fit
#' @param metric_name the name of the metric (metric_val or AUC_val)
#' @param use_weight whether we should use trait weights when evaluating weighted average of the metric
#' @param traits (optional) subset of traits. If NULL, all columns of the metric matrix are used.
#' @param force if TRUE, we recompute the best lambda index
#'
#' @return An integer denoting the best lambda index. If the validation set is available (by running check_if_metric_exists()), we return the lambda index that maximizes the validation set metric. If the AUC_val is available, we use AUC_val instead of metric_val. We take the weighted average of the metric.
#'
find_best_lambda_index <- function(multiSnpnetResults, metric_name = NULL, use_weight = TRUE, traits = NULL, force = FALSE){
  if((!force) && ('lambda_idx' %in% names(multiSnpnetResults))){
    lambda_idx <- multiSnpnetResults[['lambda_idx']]
  }else if(! check_if_metric_exists(multiSnpnetResults, 'metric_val')){
    # validation set metric is not available, meaning we only have the training set
    lambda_idx <- nrow(
      get_non_NA_lines(multiSnpnetResults[['metric_train']])
    )
  }else{
    # select the lambda index by taking the argmax
    if(use_weight){
      metric_mat <- extract_weighted_metrics(
        multiSnpnetResults, metric_name
      )
    }else{
      metric_mat <- extract_metrics(
        multiSnpnetResults, metric_name
      )
    }
    metric_mat <- as.matrix(metric_mat)
    if(!is.null(traits)){
      # drop = FALSE keeps a single-lambda (one-row) matrix two-dimensional so
      # that rowMeans() averages across traits instead of over lambdas
      metric_mat <- metric_mat[, traits, drop = FALSE]
    }
    lambda_idx <- which.max(rowMeans(metric_mat))
  }
  return(lambda_idx)
}

#' drop feature statistics and individual-level data
#'
#' Removes from a fit object every element whose name appears in \code{drop_list}.
#'
#' @param fit_obj A named list holding a fit; NULL is returned unchanged.
#' @param drop_list Names of the elements to remove when present.
#'
#' @return \code{fit_obj} without the listed elements.
drop_bulky_stuff_in_fit_object <- function(fit_obj, drop_list = c("std_obj", "response", "residuals", "stats")){
  present_names <- intersect(drop_list, names(fit_obj))
  for(element_name in present_names){
    fit_obj[[element_name]] <- NULL
  }
  fit_obj
}


#' Set the results of multiSnpnet run in a list
#'
#' @param fit_list fit_list that contains a series of fit across the lambda sequence
#' @param ilam the current lambda index
#' @param metric_train a matrix containing the metric in the training set
#' @param metric_val a matrix containing the metric in the validation set
#' @param AUC_train a matrix containing the AUC in the training set
#' @param AUC_val a matrix containing the AUC in the validation set
#' @param configs a config object (named list)
#'
#' @return An list object containing the training (and validation) set metrics, lambda_idx, fit, and configs. We will extract the results from the "best" lambda index for fit and configs. The configs itself is a list containing important paramters such as weight (trait weight).
#'   The \code{fit_list} element holds fits 1..\code{ilam}; all but the last are stripped with drop_bulky_stuff_in_fit_object().
#'
prepare_multiSnpnetResults <- function(fit_list, ilam, metric_train, metric_val, AUC_train, AUC_val, configs){
  # copy the relevant metrics
  multiSnpnetResults <- list()
  multiSnpnetResults[["configs"]] <- configs
  # keep only metrics that were computed; is.numeric(NULL) is FALSE
  metrics <- list(metric_train = metric_train, metric_val = metric_val, AUC_train = AUC_train, AUC_val = AUC_val)
  for(metric_name in names(metrics)){
    if(is.numeric(metrics[[metric_name]])){
      multiSnpnetResults[[metric_name]] <- metrics[[metric_name]]
    }
  }
  # select the best lambda index
  lambda_idx <- find_best_lambda_index(multiSnpnetResults)
  multiSnpnetResults[["lambda_idx"]] <- lambda_idx
  # copy the fit object from the best lambda index
  multiSnpnetResults[["fit"]] <- drop_bulky_stuff_in_fit_object(fit_list[[lambda_idx]])

  # add the full fit sequence
  fit_seq <- vector("list", ilam)
  # seq_len(): when ilam == 1 there is nothing to strip (seq(0) would be c(1, 0))
  for(i in seq_len(ilam - 1)){
    # drop feature statistics and individual-level data for all but the last index;
    # `[<-` with list() keeps the slot even if the fit is NULL
    fit_seq[i] <- list(drop_bulky_stuff_in_fit_object(fit_list[[i]]))
  }
  fit_seq[ilam] <- list(fit_list[[ilam]])
  multiSnpnetResults[["fit_list"]] <- fit_seq
  class(multiSnpnetResults) <- "multiSnpnetResults"
  return(multiSnpnetResults)
}


#' Load the results of multiSnpnet
#'
#' @param results_dir The results directory
#' @param last_lambda_idx The last lambda index. If not specified, we call find_prev_iter() and load the last lambda index available on the file system.
#' @param fill_sequence If TRUE, also load the fit of every lambda index from 1 to \code{last_lambda_idx} into the \code{fit_list} element; all but the last fit are stripped with drop_bulky_stuff_in_fit_object().
#'
#' @return An list object containing the training (and validation) set metrics, lambda_idx, fit, and configs. We will extract the results from the "best" lambda index for fit and configs. The configs itself is a list containing important paramters such as weight (trait weight).
#'
#' @export
#'
load_multiSnpnetResultsFromRDataFiles <- function(results_dir, last_lambda_idx = NULL, fill_sequence = FALSE){
  if(is.null(last_lambda_idx)){
    last_lambda_idx <- find_prev_iter(results_dir, nlambda = 200)
  }
  # load the data at the last lambda idx
  e_last_lambda_idx <- load_multiSnpnetRdata(results_dir, last_lambda_idx)

  # copy the relevant metrics (rows 1..last_lambda_idx)
  multiSnpnetResults <- list()
  for(metric_name in c('metric_train', 'metric_val', 'AUC_train', 'AUC_val')){
    if(check_if_metric_exists(e_last_lambda_idx, metric_name)){
      metric_mat <- e_last_lambda_idx[[metric_name]]
      # drop = FALSE keeps the matrix shape when last_lambda_idx == 1; otherwise
      # the single row collapses and the colnames assignment below fails
      multiSnpnetResults[[metric_name]] <- as.matrix(metric_mat[seq_len(last_lambda_idx), , drop = FALSE])
      colnames(multiSnpnetResults[[metric_name]]) <- colnames(metric_mat)
    }
  }
  for(obj_name in c('configs')){
    if(obj_name %in% names(e_last_lambda_idx)){
      multiSnpnetResults[[obj_name]] <- e_last_lambda_idx[[obj_name]]
    }
  }
  lambda_idx <- find_best_lambda_index(multiSnpnetResults)
  multiSnpnetResults[['lambda_idx']] <- lambda_idx

  if(lambda_idx == last_lambda_idx){
    e_lambda_idx <- e_last_lambda_idx
  }else{
    e_lambda_idx <- load_multiSnpnetRdata(results_dir, lambda_idx)
  }
  e_lambda_idx[['fit']] <- drop_bulky_stuff_in_fit_object(e_lambda_idx[['fit']])

  for(obj_name in c('fit', 'configs')){
    if(obj_name %in% names(e_lambda_idx)){
      multiSnpnetResults[[obj_name]] <- e_lambda_idx[[obj_name]]
    }
  }
  if(fill_sequence){
    fit_seq <- vector("list", last_lambda_idx)
    # files are re-read because the fit in e_lambda_idx may already be stripped
    for(idx in seq_len(last_lambda_idx)){
      e_idx <- load_multiSnpnetRdata(results_dir, idx)
      if(idx < last_lambda_idx){
        fit_seq[idx] <- list(drop_bulky_stuff_in_fit_object(e_idx[["fit"]]))
      }else{
        fit_seq[idx] <- list(e_idx[["fit"]])
      }
    }
    multiSnpnetResults[["fit_list"]] <- fit_seq
  }
  class(multiSnpnetResults) <- "multiSnpnetResults"
  return(multiSnpnetResults)
}


#' Read the lambda sequence from the initial fit with a validation set
#'
#' For a multiSnpnetResults object (e.g. loaded from a multiSnpnet run with a validation set),
#' identify the best lambda index (see find_best_lambda_index()) unless one is given,
#' and return the lambda sequence stored in \code{configs$lambda} up to that index.
#' This is useful for performing a "refit" using a combined set of training + validation set.
#'
#' @param multiSnpnetResults a list containing the results of the multiSnpnet fit
#' @param lambda_idx The lambda index. If not specified, we call find_best_lambda_index()
#'
#' @return A lambda sequence (the first \code{lambda_idx} values; empty when \code{lambda_idx} is 0)
#'
#' @export
#'
read_lambda_sequence_from_multiSnpnetResults <- function(multiSnpnetResults, lambda_idx = NULL){
  if(is.null(lambda_idx)){
    lambda_idx <- find_best_lambda_index(multiSnpnetResults)
  }
  # seq_len(): 1:0 would wrongly return the first lambda for lambda_idx = 0
  multiSnpnetResults[['configs']][['lambda']][seq_len(lambda_idx)]
}

#' Tabulate the best lambda index and the summary metrics at that index
#'
#' For a multiSnpnetResults object, identify the best lambda index (unless one is given)
#' and report, for each requested metric, the weighted average across traits at that
#' lambda index, using the trait weights applied by extract_weighted_metrics().
#'
#' @param multiSnpnetResults a list containing the results of the multiSnpnet fit
#' @param metrics a list of metrics to include
#' @param lambda_idx The lambda index. If not specified, we call find_best_lambda_index()
#'
#' @return A data frame of one row containing the lambda index (best and last), number of non-zero variables in the model, and the performance metrics at the best lambda index (NA for metrics that are not available)
#'
#' @export
#'
get_summary_metrics <- function(
    multiSnpnetResults, metrics = c('metric_train', 'metric_val', 'AUC_train', 'AUC_val'), lambda_idx = NULL
){
  if(is.null(lambda_idx)){
    lambda_idx <- find_best_lambda_index(multiSnpnetResults)
  }
  metrics_list <- setNames(
    lapply(metrics, function(metric_name){
      if(check_if_metric_exists(multiSnpnetResults, metric_name)){
        # extract_weighted_metrics() scales each trait column by its weight
        # normalized to sum to one, so the row sum (not the row mean) is the
        # weighted average across traits
        rowSums(
          extract_weighted_metrics(multiSnpnetResults, metric_name)
        )[lambda_idx]
      }else{
        NA
      }
    }),
    metrics
  )
  return(bind_cols(
    mutate(
      data.frame(
        best_lambda_idx = lambda_idx,
        last_lambda_idx = nrow(multiSnpnetResults[['metric_train']]),
        n_variables = nrow(get_non_zero_coefficients(multiSnpnetResults[['fit']]))
      ),
      across(everything(), as.integer)
    ),
    mutate(
      as.data.frame(metrics_list),
      across(everything(), as.numeric)
    )
  ))
}


#' Extract the non-zero coefficients from the fit object
#'
#' Returns the rows of the coefficient matrix \code{fit_obj$C} that have at least
#' one non-zero entry across the response variables, as a dense matrix.
#'
#' @param fit_obj A named list containing the results of the multisnpnet results.
#'
#' @export
get_non_zero_coefficients <- function(fit_obj){
  non_zero_rows <- get_non_zero_lines(fit_obj$C)
  as.matrix(non_zero_rows)
}


#' Given a data frame with ID_ALT column, this function separates them into two columns ID and ALT
#'
#' The \code{ID_ALT} value is split at its last underscore: the text after it becomes
#' \code{ALT} and the text before it becomes \code{ID}, so variant IDs that contain
#' underscores themselves are kept intact. Values without any underscore get
#' \code{ID = NA} and \code{ALT} equal to the whole value. \code{ALT} replaces the
#' \code{ID_ALT} column in place and \code{ID} is appended as the last column.
#'
#' @param ID_ALT_df A dataframe with ID_ALT column
#'
#' @return \code{ID_ALT_df} with the \code{ID_ALT} column replaced by \code{ALT} and \code{ID}.
#'
#' @export
separate_ID_ALT_into_two_columns <- function(ID_ALT_df){
  stopifnot("ID_ALT" %in% names(ID_ALT_df))
  id_alt <- as.character(ID_ALT_df[["ID_ALT"]])
  has_separator <- grepl("_", id_alt, fixed = TRUE)
  # greedy "^.*_" strips everything up to and including the last underscore
  alt <- sub("^.*_", "", id_alt)
  id <- ifelse(has_separator, sub("_[^_]*$", "", id_alt), NA_character_)
  out <- ID_ALT_df
  names(out)[names(out) == "ID_ALT"] <- "ALT"
  out[["ALT"]] <- alt
  out[["ID"]] <- id
  out
}


#' Extract the variant loadings of the SVD of coefficients as a data frame
#'
#' Turns the right singular vectors \code{svd_obj$v} (rows named by ID_ALT) into a
#' data frame whose leading columns are \code{ID} and \code{ALT}, followed by the
#' loading columns.
#'
#' @param svd_obj A named list containing the results of the svd, e.g. the output of tsvd_of_C_with_names().
#'
#' @export
#' @examples
#' tsvd_of_C_variant_loadings_as_data_framte(tsvd_of_C_with_names(fit_obj, 'Component', rank=10))
tsvd_of_C_variant_loadings_as_data_framte <- function(svd_obj){
  loadings_df <- rownames_to_column(as.data.frame(svd_obj$v), "ID_ALT")
  beta_ID_ALT_df <- separate_ID_ALT_into_two_columns(loadings_df)
  leading_cols <- c("ID", "ALT")
  ordered_cols <- c(leading_cols, setdiff(colnames(beta_ID_ALT_df), leading_cols))
  select(beta_ID_ALT_df, all_of(ordered_cols))
}


#' Extract the non-zero coefficients from the fit object and return it as a data frame
#'
#' Extract the coefficients (C) from the fit object where the coefficient has at least one non-zero entry across the response variables. The function also reads the pvar file so that the resulting data frame has the chromosomal position of the genetic variants.
#'
#' @param multiSnpnetResults a list containing the results of the multiSnpnet fit. Used to fill in \code{fit_obj} (its \code{fit} element) and \code{genotype_file} (its \code{configs$genotype_file}) when those are not given.
#' @param fit_obj A named list containing the results of the multisnpnet results.
#' @param genotype_file Path to the new suite of genotype files. genotype_file.pvar.zst must exist.
#' @param zstdcat_path Path to zstdcat program, needed when loading variants.
#'
#' @return The pvar rows joined (by ID and ALT) with the non-zero coefficients.
#'   An error is raised when \code{fit_obj} or \code{genotype_file} cannot be determined.
#'
#' @export
get_non_zero_coefficients_as_data_frame <- function(multiSnpnetResults = NULL, fit_obj = NULL, genotype_file = NULL, zstdcat_path = 'zstdcat'){
  if(!is.null(multiSnpnetResults)){
    if(is.null(fit_obj)){
      fit_obj <- multiSnpnetResults[['fit']]
    }
    if(is.null(genotype_file)){
      genotype_file <- multiSnpnetResults[['configs']][['genotype_file']]
    }
  }
  if(is.null(fit_obj)){
    stop("fit_obj is not available: pass fit_obj or a multiSnpnetResults object with a 'fit' element.", call. = FALSE)
  }
  if(is.null(genotype_file)){
    stop("genotype_file is not available: pass genotype_file or a multiSnpnetResults object whose configs contain 'genotype_file'.", call. = FALSE)
  }
  coeff_df <- inner_join(
    read_pvar(genotype_file, zstdcat_path),
    separate_ID_ALT_into_two_columns(
      rownames_to_column(
        as.data.frame(get_non_zero_coefficients(fit_obj)),
        "ID_ALT"
      )
    ), by = c("ID", "ALT")
  )
  return(coeff_df)
}


#' Compute the TSVD of the regression coefficient
#'
#' Compute the TSVD of the regression coefficient C and set colnames and rownames in the decomposed matrices.
#' The SVD is taken of \code{t(C)} restricted to its non-zero rows, so \code{u} is indexed by
#' the columns of C (phenotypes) and \code{v} by its non-zero rows (variants).
#'
#' @param fit_obj A named list containing the results of the multisnpnet results.
#' @param component_prefix A string used as a prefix for the column names corresponding to the latent variables
#' @param rank Desired rank of the decomposed matrices. NULL, or a value above the maximal rank, uses the maximal rank.
#'
#' @return The base::svd() result with \code{d}, \code{u}, and \code{v} truncated to \code{rank} components and named.
#'   An error is raised when C has no non-zero rows.
#'
#' @export
tsvd_of_C_with_names <- function(fit_obj, component_prefix="Component", rank=NULL){
  non_zero_C <- get_non_zero_coefficients(fit_obj)
  max_rank <- min(dim(non_zero_C))
  if(max_rank == 0){
    stop("The coefficient matrix C has no non-zero rows (or no columns); there is nothing to decompose.", call. = FALSE)
  }
  if( (is.null(rank)) || (rank > max_rank) ){
    rank <- max_rank
  }
  svd_of_C <- svd(t(as.matrix(non_zero_C)), nu = rank, nv = rank)
  component_names <- paste0(component_prefix, seq_len(rank))
  svd_of_C$d <- setNames(svd_of_C$d[seq_len(rank)], component_names)
  colnames(svd_of_C$v) <- component_names
  colnames(svd_of_C$u) <- component_names
  rownames(svd_of_C$v) <- rownames(non_zero_C)
  rownames(svd_of_C$u) <- colnames(non_zero_C)
  return(svd_of_C)
}


#' Compute the contribution scores
#'
#' Compute the relative importance of traits (or variants) for each component as defined in Tanigawa et al Nat Comm 2019,
#' i.e. the squared entries of the corresponding singular vectors.
#'
#' @param svd_obj A named list containing three matrices with u, d, and v as their names as in the
#'   output from base::svd() function. One can pass the results of base::svd(t(fit$C)) or tsvd_of_C_with_names(fit).
#'   Please note that this function assumes svd_obj$u and svd_obj$v corresponds to phenotypes and variants, respectively.
#' @param right_singular_vectors An indicator variable to specify if we compute the score for right singular vector or not. If true, we compute the variant contribution score (from svd_obj$v). If false, we compute the phenotype contribution score (from svd_obj$u).
#'
#' @return A long data frame with columns rowname, component, and contribution_score.
#'
#' @importFrom magrittr '%>%'
#' @importFrom tibble rownames_to_column
#' @importFrom tidyr gather
#'
#' @export
score_contribution <- function(svd_obj, right_singular_vectors=FALSE){
  singular_vectors <- if(right_singular_vectors) svd_obj$v else svd_obj$u
  squared_loadings <- as.data.frame(singular_vectors ^ 2)
  gather(rownames_to_column(squared_loadings), component, contribution_score, -rowname)
}


#' Compute the squared cosine scores
#'
#' Compute the relative importance of components for each trait (or variant) as defined in Tanigawa et al Nat Comm 2019.
#' For each row of the singular vectors scaled by the singular values, the squared coordinates are divided by their row sum.
#'
#' @param svd_obj A named list containing three matrices with u, d, and v as their names as in the
#'   output from base::svd() function. One can pass the results of base::svd(t(fit$C)) or tsvd_of_C_with_names(fit).
#'   Please note that this function assumes svd_obj$u and svd_obj$v corresponds to phenotypes and variants, respectively.
#' @param right_singular_vectors An indicator variable to specify if we compute the score for right singular vector or not. If true, we compute the variant squared cosine score. If false, we compute phenotype squared cosine score.
#' @param component_prefix A string used as a prefix for the column names corresponding to the latent variables
#'
#' @return A long data frame with columns rowname, component, and squared_cosine_score.
#'
#' @importFrom magrittr '%>%'
#' @importFrom tibble rownames_to_column
#' @importFrom tidyr gather
#' @importFrom dplyr mutate
#'
#' @export
score_squared_cosine <- function(svd_obj, right_singular_vectors=FALSE, component_prefix='Component'){
  if(right_singular_vectors){
    singular_vectors <- svd_obj$v
  }else{
    singular_vectors <- svd_obj$u
  }
  # nrow = length(d): with a single singular value, diag(d) alone would build a
  # d-by-d identity matrix instead of the 1 x 1 matrix [d]
  d_mat <- diag(svd_obj$d, nrow = length(svd_obj$d))
  squared_coords <- (singular_vectors %*% d_mat) ^ 2

  (squared_coords / rowSums(squared_coords)) %>%
  as.data.frame() %>% rownames_to_column() %>%
  gather(component, squared_cosine_score, -rowname) %>%
  mutate(component = str_replace(component, '^V', component_prefix))
}


#' Generate color palette
#'
#' A colorblind-friendly palette of eight named colors.
#'
#' @param key An optional argument to select a specific color in palette.
#'
#' @return The hex code for \code{key} when it names a palette color; otherwise the whole named palette.
#'
get_cb_colors <- function(key=NULL){
  palette <- c(
    black = "#000000",
    orange = "#E69F00",
    sky.blue = "#56B4E9",
    bluish.green = "#009E73",
    yellow = "#F0E442",
    blue = "#0072B2",
    vermilion = "#D55E00",
    reddish.purple = "#CC79A7"
  )
  if(!is.null(key) && (key %in% names(palette))){
    return(palette[[key]])
  }
  palette
}


#' Generate a scree plot visualization based on the decomposed coefficient matrix C.
#'
#' Plots, for each component, its squared singular value divided by the sum of all squared singular values.
#'
#' @param svd_obj A named list containing three matrices with u, d, and v as their names as in the
#'   output from base::svd() function. One can pass the results of base::svd(t(fit$C)) or tsvd_of_C_with_names(fit)
#'   (with any component_prefix).
#'
#' @import ggplot2
#' @importFrom magrittr '%>%'
#' @importFrom dplyr mutate
#' @importFrom tibble enframe
#' @examples
#' plot_scree( tsvd_of_C_with_names(fit, rank = 10) )
#'
#' @export
plot_scree <- function(svd_obj) {
  svd_obj$d %>%
  enframe %>%
  mutate(
    # d is stored in component order, so the position is the component index;
    # this works for any component_prefix and for unnamed base::svd() output
    # (parsing the "Component" prefix out of the names gave NA otherwise)
    Component = seq_along(value),
    variance_explained = value ^ 2 / sum(value ^ 2)
  ) %>%
  ggplot(aes(
    x = as.factor(Component),
    y = variance_explained
  )) +
  geom_point() +
  theme_bw(base_size = 16) +
  labs(
    title = 'Scree plot',
    x = 'Component',
    y = 'Relative variance explained'
  )
}


#' Make biplots of the multisnpnet results
#'
#' Generate a biplot visualization based on the decomposed coefficient matrix C.
#' Variants are drawn as points at the rows of V D on the main axes. Phenotypes
#' are drawn as points, plus segments from the origin, at the rows of U rescaled
#' onto the main axes. The secondary axes undo that rescaling and show U's values.
#'
#' @param svd_obj A named list containing three matrices with u, d, and v as their names as in the
#'   output from base::svd() function. One can pass the results of base::svd(t(fit$C)) or tsvd_of_C_with_names(fit)
#'   Please note that this function assumes svd_obj$u and svd_obj$v correspond to phenotypes and variants, respectively.
#' @param component A named list with elements `x` and `y` giving the indices of the
#'   components shown on the horizontal and vertical axes.
#' @param label A named list that specifies the phenotype and variant labels.
#'   The labels need to be in the same order as the rows of svd_obj$u and svd_obj$v.
#'   A NULL element is replaced by generic labels (phenotype1, variant1, ...).
#' @param n_labels A named list that specifies the number of phenotype and variant labels in the plot.
#'   The points farthest from the origin (in the two plotted components) are labeled.
#' @param color A named list that specifies the phenotype and variant colors, either as a
#'   palette name from get_cb_colors() (e.g. 'orange') or as any color ggplot2 understands.
#' @param shape A named list that specifies the phenotype and variant point shapes.
#' @param axis_label A named list that specifies the names used in the main (`main`) and
#'   secondary (`sub`) axis labels.
#' @param use_ggrepel A logical value that specifies whether ggrepel is used to annotate the
#'   labels of the data points. Use FALSE when converting the plot with plotly::ggplotly().
#' @return A ggplot object.
#'
#' @import ggplot2
#' @importFrom magrittr '%>%'
#' @importFrom tibble rownames_to_column
#' @importFrom dplyr rename select mutate if_else bind_rows
#'
#' @examples
#' plot_biplot(svd(t(fit$C)), label=list('phenotype'=rownames(A_init), 'variant'=rownames(fit$C)))
#'
#' @export
plot_biplot <- function(svd_obj, component=list('x'=1, 'y'=2),
                        label=list('phenotype'=NULL, 'variant'=NULL),
                        n_labels=list('phenotype'=5, 'variant'=5),
                        color=list('phenotype'='orange', 'variant'='sky.blue'),
                        shape=list('phenotype'=20, 'variant'=4),
                        axis_label=list('main'='variant', 'sub'='phenotype'),
                        use_ggrepel=TRUE) {
  n_components <- length(svd_obj$d)
  component_idx <- c(component$x, component$y)
  if (length(component_idx) != 2L || !all(component_idx %in% seq_len(n_components))) {
    stop(sprintf('component$x and component$y must be indices between 1 and %d',
                 n_components), call. = FALSE)
  }

  # extract the relevant matrices from the svd object.
  # nrow= stops diag() from building an identity matrix when d has length 1.
  u  <- svd_obj$u
  vd <- svd_obj$v %*% diag(svd_obj$d, nrow = n_components)

  # assign row and col names
  if (is.null(label[['phenotype']])) label[['phenotype']] <- paste0('phenotype', seq_len(nrow(u)))
  if (is.null(label[['variant']]))   label[['variant']]   <- paste0('variant',   seq_len(nrow(vd)))
  rownames(u)  <- label[['phenotype']]
  rownames(vd) <- label[['variant']]
  colnames(u)  <- seq_len(n_components)
  colnames(vd) <- seq_len(n_components)

  # resolve palette names (see get_cb_colors) into color codes stored under
  # 'plot_<key>'; values that are not palette names are used as-is
  cb.colors <- get_cb_colors()
  for (k in names(color)) {
    color[[sprintf('plot_%s', k)]] <-
      if (color[[k]] %in% names(cb.colors)) cb.colors[[color[[k]]]] else color[[k]]
  }

  # convert a loading matrix into a data frame holding the two selected
  # components (PC_x, PC_y) and a label column that keeps only the n_keep
  # labels farthest from the origin
  to_plot_df <- function(mat, n_keep) {
    mat %>% as.data.frame() %>%
      rename('PC_x' := component$x, 'PC_y' := component$y) %>%
      select(PC_x, PC_y) %>%
      rownames_to_column('label') %>%
      mutate(label = if_else(rank(-(PC_x^2 + PC_y^2)) <= n_keep, label, ''))
  }
  df_u  <- to_plot_df(u,  n_labels[['phenotype']])
  df_vd <- to_plot_df(vd, n_labels[['variant']])

  # scale u (data on sub-axis) to map to the main-axis
  lim_u_abs  <- 1.1 * max(abs(df_u  %>% select(PC_x, PC_y)))
  lim_vd_abs <- 1.1 * max(abs(df_vd %>% select(PC_x, PC_y)))
  u_to_vd    <- lim_vd_abs / lim_u_abs
  df_u_scaled <- df_u %>%
    mutate(
      PC_x = PC_x * u_to_vd,
      PC_y = PC_y * u_to_vd
    )

  # point mapping for one group ('variant' or 'phenotype'). Without ggrepel the
  # label aesthetic is kept so plotly::ggplotly() can show it as hover text.
  point_aes <- function(group) {
    if (use_ggrepel) {
      aes(x = PC_x, y = PC_y, shape = !!group)
    } else {
      aes(x = PC_x, y = PC_y, shape = !!group, label = label)
    }
  }

  p <- ggplot() +
    layer(
      # scatter plot for (VD)
      data = df_vd, mapping = point_aes('variant'),
      geom = 'point', stat = "identity", position = "identity",
      params = list(size = 1, color = color[['plot_variant']])
    ) +
    layer(
      # segments (lines) for U
      data = df_u_scaled,
      mapping = aes(x = 0, y = 0, xend = PC_x, yend = PC_y),
      geom = 'segment', stat = "identity", position = "identity",
      # NOTE: ggplot2 >= 3.4 prefers `linewidth` for lines; `size` keeps older versions working
      params = list(size = 1, color = color[['plot_phenotype']], alpha = .1)
    ) +
    layer(
      # scatter plot for U
      data = df_u_scaled, mapping = point_aes('phenotype'),
      geom = 'point', stat = "identity", position = "identity",
      params = list(size = 1, color = color[['plot_phenotype']])
    )

  if (use_ggrepel) {
    # annotate both groups with non-overlapping text labels
    p <- p + ggrepel::geom_text_repel(
      data = bind_rows(
        df_u_scaled %>% mutate(color = color[['plot_phenotype']]),
        df_vd       %>% mutate(color = color[['plot_variant']])
      ),
      mapping = aes(x = PC_x, y = PC_y, label = label, color = color),
      size = 3, force = 10
    )
  }

  # configure the theme, axis, and axis labels
  p + theme_bw() +
    # the text color column already holds color codes, so map it unchanged
    scale_color_identity() +
    scale_shape_manual(values = unlist(shape)) +
    guides(shape = 'none', color = 'none') +
    scale_x_continuous(
      sprintf('Component %s (%s [%s])', component$x, axis_label[['main']], color[['variant']]),
      limits = c(-lim_vd_abs, lim_vd_abs),
      sec.axis = sec_axis(
        ~ . * (lim_u_abs / lim_vd_abs),
        name = sprintf('Component %s (%s [%s])', component$x, axis_label[['sub']], color[['phenotype']])
      )
    ) +
    scale_y_continuous(
      sprintf('Component %s (%s [%s])', component$y, axis_label[['main']], color[['variant']]),
      limits = c(-lim_vd_abs, lim_vd_abs),
      sec.axis = sec_axis(
        ~ . * (lim_u_abs / lim_vd_abs),
        name = sprintf('Component %s (%s [%s])', component$y, axis_label[['sub']], color[['phenotype']])
      )
    )
}
junyangq/multiSnpnet documentation built on Oct. 19, 2023, 8:22 p.m.