R/mefisto.R

Defines functions interpolate_factors .plot_factors_vs_cov_2d .plot_factors_vs_cov_1d plot_factors_vs_cov plot_data_vs_cov plot_interpolation_vs_covariate plot_sharedness plot_smoothness plot_group_kernel get_default_mefisto_options get_covariates set_covariates

Documented in get_covariates get_default_mefisto_options interpolate_factors plot_data_vs_cov plot_factors_vs_cov plot_group_kernel plot_interpolation_vs_covariate plot_sharedness plot_smoothness set_covariates

##########################################################################
## Functions to use continuous covariates, as part of the MEFISTO framework ##
##########################################################################

#' @title Add covariates to a MOFA model
#' @name set_covariates
#' @description Function to add continuous covariate(s) to a \code{\link{MOFA}} object for training with MEFISTO
#' @param object an untrained \code{\link{MOFA}}
#' @param covariates Sample-covariates to be passed to the model.
#' This can be either:
#' \itemize{
#'   \item{a character, specifying columns already present in the samples_metadata of the object}
#'   \item{a data.frame with columns "sample", "covariate", "value". Sample names need to match those present in the data}
#'   \item{a matrix with samples in columns and covariate(s) in row(s)}
#'  }
#' Note that the covariate should be numeric and continuous.
#' @return Returns an untrained \code{\link{MOFA}} with covariates filled in the corresponding slots
#' @details To activate the functional MEFISTO framework, specify mefisto_options when preparing the training using \code{prepare_mofa} 
#' @export
#' @examples
#' # Simulate data
#' dd <- make_example_data(sample_cov = seq(0,1,length.out = 100), n_samples = 100, n_factors = 4)
#' 
#' # Create MOFA object
#' sm <- create_mofa(data = dd$data)
#' 
#' # Add a covariate
#' sm <- set_covariates(sm, covariates = dd$sample_cov)
#' sm

set_covariates <- function(object, covariates) {

  # Sanity checks
  if (!is(object, "MOFA")) 
    stop("'object' has to be an instance of MOFA")
  if (object@status=="trained") 
    stop("The model is already trained! Covariates must be added before training")
  
  # Sample names of the data, concatenated across groups in group order
  samples_data_vec <- unlist(samples_names(object))
  
  # A plain numeric vector is treated as a single covariate: promote it to a
  # 1-row matrix (samples in columns), keeping its names (if any) as sample names
  if (is.numeric(covariates) && is.null(dim(covariates))) {
    covariates <- matrix(covariates, nrow = 1, dimnames = list(NULL, names(covariates)))
  }
  
  # covariates passed as characters: extract the columns from the metadata and
  # reshape them into the long (sample, covariate, value) format handled below
  if (is(covariates, "character")) {
    if (!all(covariates %in% colnames(samples_metadata(object)))) {
      stop("Columns specified in covariates do not exist in the MOFA object metadata slot.")
    }
    covariates <- samples_metadata(object)[, c("sample", covariates), drop = FALSE]
    covariates <- gather(covariates, key = "covariate", value = "value", -sample)
    if (!is.numeric(covariates$value)) {
      stop("Covariates need to be numeric")
    }
    # TO-DO: Check that they continuous
  }
  
  if (inherits(covariates, c("data.frame", "tibble", "Data.Frame"))) {
    # covariates passed in long data.frame format
    if (!all(c("sample", "covariate", "value") %in% colnames(covariates)))
      stop("If covariates is provided as data.frame it needs to contain the columns: sample, covariate, value")
    if (!is.numeric(covariates$value)) {
      stop("Values in covariates need to be numeric")
    }
    # wide matrix: covariates in rows, samples in columns
    covariates <- reshape2::acast(covariates, covariate ~ sample)
    # every sample of the data needs a covariate value (extra samples are dropped below)
    if (!all(samples_data_vec %in% colnames(covariates)))
      stop("Sample names of the data and the sample covariates do not match.")
    
  } else if (is.numeric(covariates) || inherits(covariates, c("dgTMatrix", "dgCMatrix"))) {
    # covariates passed in matrix format (covariates in rows, samples in columns)
    samples <- colnames(covariates)
    if (!is.null(samples)) {
      if (!(all(samples %in% samples_data_vec) && all(samples_data_vec %in% samples)))
        stop("Sample names of the data and the sample covariates do not match.")
      # reorder columns to match the sample order of the data
      covariates <- covariates[, samples_data_vec, drop = FALSE]
    } else {
      # no sample names: rely on the column order matching the data
      if (sum(object@dimensions[["N"]]) != ncol(covariates))
        stop("Number of columns in sample covariates does not match the number of samples")
      if (length(samples_data_vec) > 0) {
        warning("No sample names in covariates - we will use the sample names in data. Please ensure that the order matches.")
        # name the columns after the data samples so they can be split by group below
        colnames(covariates) <- samples_data_vec
      } else {
        stop("No sample names found!")
      }
    }
    
  } else {
    # covariates format not recognised
    stop("covariates needs to be a character vector, a dataframe, a matrix or NULL.")
  }
    
  # Set covariate dimensionality
  object@dimensions[["C"]] <- nrow(covariates)
    
  # Set covariate names
  if (is.null(rownames(covariates))) {
    message("No covariates names provided - using generic: covariate1, covariate2, ...")
    rownames(covariates) <- paste0("covariate", seq_len(nrow(covariates)))
  }
  
  # split covariates by groups (one covariates x samples matrix per group)
  covariates <- lapply(samples_names(object), function(i) covariates[, i, drop = FALSE])
  names(covariates) <- groups_names(object)
  
  # Sanity checks
  stopifnot(all(vapply(covariates, ncol, numeric(1)) == object@dimensions[["N"]]))
  
  # add covariates to the MOFA object
  object@covariates <- covariates
  
  return(object)
}


#' @title Get sample covariates
#' @name get_covariates
#' @description Function to extract the covariates from a \code{\link{MOFA}} object using MEFISTO.
#' @param object a \code{\link{MOFA}} object.
#' @param covariates character vector with the covariate name(s), or numeric vector with the covariate index(es). 
#' @param as.data.frame logical indicating whether to output the result as a long data frame, default is \code{FALSE}.
#' @param warped logical indicating whether to extract the aligned covariates
#' @return a list with one matrix per group, each with dimensions (covariates, samples). If \code{as.data.frame} is \code{TRUE}, a long-formatted data frame with columns (covariate, sample, value)
#' @export
#' @examples 
#' # Using an existing trained model
#' file <- system.file("extdata", "MEFISTO_model.hdf5", package = "MOFA2")
#' model <- load_model(file)
#' covariates <- get_covariates(model)

get_covariates <- function(object, covariates = "all", as.data.frame = FALSE, warped = FALSE) {
  # Sanity checks
  if (!is(object, "MOFA")) stop("'object' has to be an instance of MOFA")
  
  # Get and check covariate names
  covariates <- .check_and_get_covariates(object, covariates)
  
  # Use the aligned (warped) or the original covariates; both are lists with one
  # matrix per group (covariates in rows, samples in columns)
  cov_list <- if (warped) object@covariates_warped else object@covariates
  sample_cov <- lapply(cov_list, function(cmat) cmat[covariates, , drop = FALSE])
  
  if (as.data.frame) {
    sample_cov <- Reduce(cbind, sample_cov) # remove group info
    # Fall back to generic covariate names if the matrices carry none,
    # so the 'covariate' column of the long data frame is always informative
    if (!is.null(sample_cov) && is.null(rownames(sample_cov))) {
      rownames(sample_cov) <- paste0("covariate_", seq_along(covariates))
    }
    sample_cov <- melt(sample_cov, varnames = c("covariate", "sample"))
  }
  
  return(sample_cov)
}


#' @title Get default options for MEFISTO covariates
#' @name get_default_mefisto_options
#' @description Function to obtain the default options for using covariates with MEFISTO
#' @param object an untrained \code{\link{MOFA}} object
#' @details The options are the following: \cr
#' \itemize{
#'  \item{\strong{scale_cov}:}  logical: Scale covariates?
#'  \item{\strong{start_opt}:} integer: First iteration to start the optimisation of GP hyperparameters
#'  \item{\strong{n_grid}:} integer: Number of points for the grid search in the optimisation of GP hyperparameters
#'  \item{\strong{opt_freq}:} integer: Frequency of optimisation of GP hyperparameters
#'  \item{\strong{sparseGP}:} logical: Use sparse GPs to speed up the optimisation of the GP parameters?
#'  \item{\strong{frac_inducing}:} numeric between 0 and 1: Fraction of samples to use as inducing points (only relevant if sparseGP is \code{TRUE})
#'  \item{\strong{warping}:}   logical: Activate warping functionality to align covariates between groups (requires a multi-group design)
#'  \item{\strong{warping_freq}:} numeric: frequency of the warping (only relevant if warping is \code{TRUE})
#'  \item{\strong{warping_ref}:} A character specifying the reference group for warping (only relevant if warping is \code{TRUE})
#'  \item{\strong{warping_open_begin}:} logical: Warping: Allow for open beginning? (only relevant if warping is \code{TRUE})
#'  \item{\strong{warping_open_end}:} logical: Warping: Allow for open end? (only relevant if warping is \code{TRUE})
#'  \item{\strong{warping_groups}:} Assignment of groups to classes used for alignment (advanced option). 
#'  Needs to be a vector of length number of samples, e.g. a column of samples_metadata, which needs to have the same value within each group.
#'  By default groups are used specified in `create_mofa`.
#'  \item{\strong{model_groups}:} logical: Model covariance structure across groups (for more than one group, otherwise FALSE)? If FALSE, we assume the same patterns in all groups. 
#'  \item{\strong{new_values}:} Values for which to predict the factor values (for interpolation / extrapolation). 
#'  This should be numeric matrix in the same format with covariate(s) in rows and new values in columns.
#'  Default is NULL, leading to no interpolation.
#' }
#' @return Returns a list with default options for the MEFISTO covariate(s) functionality.
#' @importFrom utils modifyList
#' @export
#' @examples 
#' # generate example data
#' dd <- make_example_data(sample_cov = seq(0,1,length.out = 200), n_samples = 200,
#' n_factors = 4, n_features = 200, n_views = 4, lscales = c(0.5, 0.2, 0, 0))
#' # input data
#' data <- dd$data
#' # covariate matrix with samples in columns
#' time <- dd$sample_cov
#' rownames(time) <- "time"
#' 
#' # create mofa and set covariates
#' sm <- create_mofa(data = dd$data)
#' sm <- set_covariates(sm, covariates = time)
#' 
#' MEFISTO_opt <- get_default_mefisto_options(sm)

get_default_mefisto_options <- function(object) {
  
  # The first group acts as the default reference for warping
  reference_group <- groups_names(object)[[1]]
  
  defaults <- list(
    # --- standard options ---
    scale_cov = FALSE,           # scale the covariates?
    start_opt = 20,              # iteration at which GP hyperparameter optimisation starts
    n_grid = 20,                 # grid size for the GP hyperparameter search
    opt_freq = 10,               # optimise GP hyperparameters every opt_freq iterations
    model_groups = TRUE,         # model the covariance structure across groups
    # --- sparse GP options ---
    sparseGP = FALSE,            # use sparse GPs for faster optimisation
    frac_inducing = 0.75,        # fraction of samples used as inducing points
    # --- warping options ---
    warping = FALSE,             # align covariates between groups (multi-group only)
    warping_freq = 20,           # warping frequency
    warping_ref = reference_group,
    warping_open_begin = TRUE,   # allow open beginning when warping
    warping_open_end = TRUE,     # allow open end when warping
    warping_groups = NULL,
    # --- interpolation / extrapolation ---
    new_values = NULL
  )
  
  # With a single group there is no between-group covariance to model
  if (object@dimensions$G == 1) {
    defaults[["model_groups"]] <- FALSE
  }
  
  # Options already stored in the object override the defaults;
  # options the defaults do not know about are kept as well
  stored <- object@mefisto_options
  if (length(stored) == 0) {
    return(defaults)
  }
  modifyList(defaults, stored)
}



#' @title Heatmap plot showing the group-group correlations per factor
#' @name plot_group_kernel
#' @description Heatmap plot showing the group-group correlations inferred by the model per factor
#' @param object a trained \code{\link{MOFA}} object using MEFISTO.
#' @param factors character vector with the factors names, or numeric vector indicating the indices of the factors to use
#' @param groups character vector with the groups names, or numeric vector with the indices of the groups of samples to use, or "all" to use samples from all groups.
#' @param ... additional parameters that can be passed to  \code{pheatmap} 
#' @details The heatmap gives insight into the clustering of the patterns that factors display along the covariate in each group. 
#' A correlation of 1 indicates that the module captured by a factor shows identical patterns across groups, a correlation of zero that it shows distinct patterns,
#' a negative correlation that the patterns go in opposite directions.
#' @return Returns a \code{ggplot,gg} object containing the heatmaps
#' @import pheatmap 
#' @import cowplot
#' @export
#' @examples
#' # Using an existing trained model on simulated data
#' file <- system.file("extdata", "MEFISTO_model.hdf5", package = "MOFA2")
#' model <- load_model(file)
#' plot_group_kernel(model)
plot_group_kernel <- function(object, factors = "all", groups = "all", ...) {
  
  if (!is(object, "MOFA")) stop("'object' has to be an instance of MOFA")
  
  # Resolve the factor and group selections
  factors <- .check_and_get_factors(object, factors)
  groups <- .check_and_get_groups(object, groups)
  
  # Group-group correlation matrices, one per factor
  group_kernels <- get_group_kernel(object)
  
  # Shared diverging colour scale on [-1, 1] with its midpoint at 0:
  # n_colours colours need n_colours + 1 breaks
  n_colours <- 100
  step <- 1/n_colours * 2
  heat_breaks <- c(seq(-1, 0, step), seq(0, 1, step)[-1])
  heat_colours <- rev(colorRampPalette(RColorBrewer::brewer.pal(n = 7, name = "RdBu"))(n_colours))
  
  # One heatmap per factor, restricted to the selected groups
  heatmaps <- lapply(factors, function(f) {
    kernel <- group_kernels[[f]][groups, groups]
    hm <- pheatmap::pheatmap(kernel, color = heat_colours, breaks = heat_breaks, silent = TRUE, ...)
    hm$gtable
  })
  
  cowplot::plot_grid(plotlist = heatmaps)
}



#' @title Barplot showing the smoothness per factor
#' @name plot_smoothness
#' @description Barplot indicating a smoothness score (between 0 (non-smooth) and 1 (smooth)) per factor
#' @param object a trained \code{\link{MOFA}} object using MEFISTO.
#' @param factors character vector with the factors names, or numeric vector indicating the indices of the factors to use
#' @param color color used for the smooth part of the bar
#' @details The smoothness score is given by the scale parameter for the underlying Gaussian process of each factor.
#' @return Returns a \code{ggplot2} object
#' @import ggplot2
#' @importFrom tidyr gather
#' @export
#' @examples 
#' # Using an existing trained model
#' file <- system.file("extdata", "MEFISTO_model.hdf5", package = "MOFA2")
#' model <- load_model(file)
#' smoothness_bars <- plot_smoothness(model)

plot_smoothness <- function(object, factors = "all", color = "cadetblue") {
  
  # Sanity checks
  if (!is(object, "MOFA")) stop("'object' has to be an instance of MOFA")
  
  # Define factors
  factors <- .check_and_get_factors(object, factors)
  
  # Scale parameters of the factors' Gaussian processes: the smooth share of each
  # bar is the scale, the remainder (1 - scale) is drawn as non-smooth
  ss <- get_scales(object)[factors]
  df <- data.frame(factor = names(ss), smooth = ss, non_smooth = 1 - ss)
  df$factor <- factor(df$factor, levels = factors)
  df <- gather(df, -factor, key = "smoothness", value = "value")
  
  gg_bar <- ggplot(df, aes(x = 1, y = value, fill = smoothness)) +
    geom_bar(stat = "identity") +
    facet_wrap(~factor, nrow = 1) +
    theme_void() + coord_flip() +
    # "none" replaces the deprecated guides(fill = FALSE)
    guides(fill = "none") +
    scale_fill_manual(values = c("non_smooth" = "gray", "smooth" = color)) +
    # annotate() draws the label once per panel (geom_text drew it once per data row)
    annotate("text", x = 1, y = 0.5, label = "smoothness", size = 3)

  return(gg_bar)
}


#' @title Barplot showing the sharedness per factor
#' @name plot_sharedness
#' @description Barplot indicating a sharedness score (between 0 (non-shared) and 1 (shared)) per factor
#' @param object a trained \code{\link{MOFA}} object using MEFISTO.
#' @param factors character vector with the factors names, or numeric vector indicating the indices of the factors to use
#' @param color color used for the shared part of the bar
#' @details The sharedness score is calculated as the distance of the learnt group correlation matrix to the identity matrix
#'  in terms of the mean absolute distance on the off-diagonal elements.
#' @return Returns a \code{ggplot2} object
#' @import ggplot2
#' @export

plot_sharedness <- function(object, factors = "all", color = "#B8CF87") {
  
  # Sanity checks
  if (!is(object, "MOFA")) stop("'object' has to be an instance of MOFA")
  if (object@dimensions$G == 1) stop("'object' has only one group, more than one group are required to determine sharedness.")
  
  # Define factors
  factors <- .check_and_get_factors(object, factors)
  
  # Get group kernels (group x group correlation matrix per factor)
  Kgs <- get_group_kernel(object)[factors]
  
  # Mean absolute distance to the identity on the off-diagonal elements
  # (vapply keeps the factor names of Kgs and guarantees a numeric vector)
  idmat <- diag(1, ncol(Kgs[[1]]))
  off_diag <- lower.tri(idmat)
  gr <- vapply(Kgs, function(k) mean(abs(k - idmat)[off_diag]), numeric(1))
  
  # make plot
  df <- data.frame(factor = names(gr), group = gr, non_group = 1 - gr)
  df$factor <- factor(df$factor, levels = factors)
  df <- gather(df, -factor, key = "sharedness", value = "value")
  df <- mutate(df, sharedness = factor(sharedness, levels = rev(c("group", "non_group"))))
  gg_bar <- ggplot(df, aes(x = 1, y = value, fill = sharedness)) +
    geom_bar(stat = "identity") +
    facet_wrap(~factor, nrow = 1) +
    theme_void() + coord_flip() +
    # "none" replaces the deprecated guides(fill = FALSE)
    guides(fill = "none") +
    scale_fill_manual(values = c("non_group" = "gray", "group" = color)) +
    # annotate() draws the label once per panel (geom_text drew it once per data row)
    annotate("text", x = 1, y = 0.5, label = "sharedness", size = 3)
  
  return(gg_bar)
}

#' @title Plot interpolated factors versus covariate (1-dimensional)
#' @name plot_interpolation_vs_covariate
#' @description Plot interpolated factor values against a covariate
#' @param object a trained \code{\link{MOFA}} object using MEFISTO.
#' @param covariate covariate to use for plotting
#' @param factors character or numeric specifying the factor(s) to plot, default is "all"
#' @param only_mean show only mean or include uncertainties?
#' @param show_observed include observed factor values as dots on the plot
#' @details The interpolated factor values need to be computed beforehand, e.g. using \code{\link{interpolate_factors}}.
#' @return Returns a \code{ggplot2} object
#' @import ggplot2
#' @export
#' @examples 
#' # Using an existing trained model
#' file <- system.file("extdata", "MEFISTO_model.hdf5", package = "MOFA2")
#' model <- load_model(file)
#' model <- interpolate_factors(model, new_values = seq(0,1.1,0.1))
#' plot_interpolation_vs_covariate(model, covariate = "time", factors = 1)

plot_interpolation_vs_covariate <- function(object, covariate = 1, factors = "all", only_mean = TRUE, show_observed = TRUE){

  if (!is(object, "MOFA")) stop("'object' has to be an instance of MOFA")

  # Resolve the covariate and factor selections
  covariate <- .check_and_get_covariates(object, covariate)
  factors <- .check_and_get_factors(object, factors)

  # Interpolated factor values for the selected factors, in the requested order
  interp <- get_interpolated_factors(object, as.data.frame = TRUE, only_mean = only_mean)
  interp <- filter(interp, factor %in% factors)
  interp$factor <- factor(interp$factor, levels = factors)

  # Ribbon borders: mean +/- 1.96 standard deviations
  if (!only_mean) {
    interp <- mutate(interp,
                     sd = sqrt(variance),
                     ymin = mean - 1.96 * sd,
                     ymax = mean + 1.96 * sd)
  }

  # Interpolated curves, one panel per factor
  p <- ggplot(interp, aes(x = .data[[covariate]], y = .data$mean, col = .data$group)) +
    geom_line(aes(y = mean, col = group)) +
    facet_wrap(~ factor) +
    theme_classic() +
    ylab("factor value")

  # Overlay the factor values at the observed covariate values
  if (show_observed) {
    observed <- plot_factors_vs_cov(object, covariates = covariate, return_data = TRUE)
    observed <- filter(observed, factor %in% factors)
    observed$factor <- factor(observed$factor, levels = factors)
    p <- p + geom_point(data = observed,
                        aes(x = value.covariate, y = value.factor, col = group),
                        size = 1)
  }

  # Uncertainty band
  if (!only_mean) {
    p <- p + geom_ribbon(aes(ymin = ymin, ymax = ymax, fill = group),
                         alpha = .2, col = "gray", size = 0.1)
  }

  p
}




#' @title Scatterplots of feature values against sample covariates
#' @name plot_data_vs_cov
#' @description Function to do a scatterplot of features against sample covariate values.
#' @param object a \code{\link{MOFA}} object using MEFISTO.
#' @param covariate string with the covariate name or a samples_metadata column, or an integer with the index of the covariate
#' @param warped logical indicating whether to show the aligned covariate (default: TRUE), 
#' only relevant if warping has been used to align multiple sample groups
#' @param factor string with the factor name, or an integer with the index of the factor to take top features from
#' @param view string with the view name, or an integer with the index of the view. Default is the first view.
#' @param groups groups to plot. Default is "all".
#' @param features if an integer (default), the total number of features to plot (given by highest weights). If a character vector, a set of manually-defined features.
#' @param sign can be 'positive', 'negative' or 'all' (default) to show only features with highest positive, negative or all weights, respectively.
#' @param color_by specifies groups or values (either discrete or continuous) used to color the dots (samples). This can be either: 
#' \itemize{
#' \item the string "group": dots are coloured with respect to their predefined groups.
#' \item a character giving the name of a feature that is present in the input data 
#' \item a character giving the name of a column in the sample metadata slot
#' \item a vector of the same length as the number of samples specifying the value for each sample. 
#' \item a dataframe with two columns: "sample" and "color"
#' }
#' @param shape_by specifies groups or values (only discrete) used to shape the dots (samples). This can be either: 
#' \itemize{
#' \item the string "group": dots are shaped with respect to their predefined groups.
#' \item a character giving the name of a feature that is present in the input data 
#' \item a character giving the name of a column in the sample metadata slot
#' \item a vector of the same length as the number of samples specifying the value for each sample. 
#' \item a dataframe with two columns: "sample" and "shape"
#' }
#' @param legend logical indicating whether to add a legend
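#' @param dot_size numeric indicating dot size (default is 2.5).
#' @param text_size numeric indicating the text size of the Pearson correlation labels (default is NULL, inferred automatically; only used if \code{add_lm} is \code{TRUE}).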
#' @param stroke numeric indicating the stroke size (the black border around the dots, default is NULL, inferred automatically).
#' @param alpha numeric indicating dot transparency (default is 1).
#' @param add_lm logical indicating whether to add a linear regression line for each plot
#' @param lm_per_group logical indicating whether to add a linear regression line separately for each group
#' @param imputed logical indicating whether to include imputed measurements
#' @param return_data logical indicating whether to return a data frame instead of a plot
#' @details One of the first steps for the annotation of factors is to visualise the weights using \code{\link{plot_weights}} or \code{\link{plot_top_weights}}
#' and inspect the relationship of the factor to the covariate(s) using \code{\link{plot_factors_vs_cov}}.
#' However, one might also be interested in visualising the direct relationship between features and covariate(s), rather than looking at "abstract" weights,
#' and possibly in looking at the interpolated and extrapolated values by setting \code{imputed} to \code{TRUE}.
#' @import ggplot2
# #' @importFrom ggpubr stat_cor
#' @importFrom dplyr left_join
#' @importFrom utils tail
#' @importFrom stats quantile
#' @return Returns a \code{ggplot2} object or the underlying dataframe if return_data is set to \code{TRUE}.
#' @export
#' @examples 
#' # Using an existing trained model
#' file <- system.file("extdata", "MEFISTO_model.hdf5", package = "MOFA2")
#' model <- load_model(file)
#' plot_data_vs_cov(model, factor = 3, features = 2)

plot_data_vs_cov <- function(object, covariate = 1, warped = TRUE, factor = 1, view = 1, groups = "all", features = 10, sign = "all",
                              color_by = "group", legend = TRUE, alpha = 1, shape_by = NULL, stroke = NULL,
                              dot_size = 2.5, text_size = NULL, add_lm = FALSE, lm_per_group = FALSE, imputed = FALSE, return_data = FALSE) {
  
  # Sanity checks
  if (!is(object, "MOFA")) stop("'object' has to be an instance of MOFA")
  stopifnot(length(factor)==1)
  stopifnot(length(covariate)==1)
  stopifnot(length(view)==1)
  # per-group regression lines imply adding regression lines at all
  if (lm_per_group) add_lm <- TRUE
  
  # Define views, factors and groups
  groups <- .check_and_get_groups(object, groups)
  factor <- .check_and_get_factors(object, factor)
  view <- .check_and_get_views(object, view)
  
  # Check and fetch covariates (long format: covariate, sample, value)
  df1 <- get_covariates(object, covariate, as.data.frame = TRUE, warped = warped) 
  covariate_name <- unique(df1$covariate)
  if(!warped){
    covariate_name <- paste(covariate_name, "(unaligned)")
  }
  
  # Weights of the selected factor in the selected view
  W <- get_weights(object)[[view]][,factor]
  
  # Restrict/transform the weights according to the requested sign
  if (sign=="all") {
    W <- abs(W)
  } else if (sign=="positive") {
    W <- W[W>0]
  } else if (sign=="negative") {
    W <- W[W<0]
  }
  
  # Get features: a single number is the count of top features by |weight|,
  # a numeric vector gives ranks, a character vector gives feature names
  if (is(features, "numeric")) {
    if (length(features) == 1) {
      features <- names(tail(sort(abs(W)), n=features))
    } else {
      features <- names(sort(-abs(W))[features])
    }
    stopifnot(all(features %in% features_names(object)[[view]]))  
  } else if (is(features, "character")) {
    stopifnot(all(features %in% features_names(object)[[view]]))
  } else {
    stop("Features need to be either a numeric or character vector")
  }

  # Legend titles are only known when color_by/shape_by are given by name;
  # initialise them so .add_legend() never references an undefined variable
  color_name <- NULL
  shape_name <- NULL
  if (length(color_by)==1 && is.character(color_by)) color_name <- color_by
  if (length(shape_by)==1 && is.character(shape_by)) shape_name <- shape_by
  color_by <- .set_colorby(object, color_by)
  shape_by <- .set_shapeby(object, shape_by)
  
  # Merge covariate values with color and shape information
  df1 <- merge(df1, color_by, by="sample")
  df1 <- merge(df1, shape_by, by="sample")
  
  # Create data frame with the (observed or imputed) feature values
  foo <- list(features); names(foo) <- view
  if (imputed) {
    df2 <- get_imputed_data(object, groups = groups, views = view, features = foo, as.data.frame = TRUE)
  } else {
    df2 <- get_data(object, groups = groups, features = foo, as.data.frame = TRUE)
  }
  
  df2$sample <- as.character(df2$sample)
  df <- left_join(df1, df2, by = "sample", suffix = c(".covariate",".data"))
  
  # Remove samples with missing covariate or feature values
  df <- df[!is.na(df$value.covariate) & !is.na(df$value.data) ,]
  
  if(return_data){
    return(df)
  }
  
  # Set stroke
  if (is.null(stroke)) {
    stroke <- .select_stroke(N=length(unique(df$sample)))
  }
  
  # Set Pearson text size
  if (add_lm && is.null(text_size)) {
    text_size <- .select_pearson_text_size(N=length(unique(df$feature)))
  }
  
  # Set axis text size
  axis.text.size <- .select_axis.text.size(N=length(unique(df$feature)))
  
  # Generate plot
  p <- ggplot(df, aes(x = .data[["value.covariate"]], y = .data[["value.data"]])) + 
    geom_point(aes(fill = .data$color_by, shape = .data$shape_by), colour = "black", size = dot_size, stroke = stroke, alpha = alpha) +
    labs(x=covariate_name, y="") +
    facet_wrap(~feature, scales="free_y") +
    theme_classic() + 
    theme(
      axis.text = element_text(size = rel(axis.text.size), color = "black"), 
      axis.title = element_text(size = rel(1.0), color="black")
    )
  
  # Add linear regression line
  if (add_lm) {
    if (lm_per_group && length(groups)>1) {
      p <- p +
        stat_smooth(formula=y~x, aes(color=.data$group), method="lm", alpha=0.4) +
        # r.label is computed by stat_cor, so it must be mapped with after_stat()
        # (a .data lookup happens before the stat and fails)
        ggpubr::stat_cor(aes(color=.data$group, label = after_stat(r.label)), method = "pearson", label.sep="\n", output.type = "latex", size = text_size)
    } else {
      p <- p +
        stat_smooth(formula=y~x, method="lm", color="grey", fill="grey", alpha=0.4) +
        ggpubr::stat_cor(method = "pearson", label.sep="\n", output.type = "latex", size = text_size, color = "black")
    }
  }
  
  # Add legend
  p <- .add_legend(p, df, legend, color_name, shape_name)
  
  return(p)
}


#' @title Scatterplots of a factor's values against the sample covariates
#' @name plot_factors_vs_cov
#' @description Scatterplots of a factor's values against the sample covariates
#' @param object a trained \code{\link{MOFA}} object using MEFISTO.
#' @param factors character or numeric specifying the factor(s) to plot, default is "all"
#' @param covariates specifies sample covariate(s) to plot against:
#' (1) a character giving the name of a column present in the sample covariates or sample metadata.
#' (2) a character giving the name of a feature present in the training data.
#' (3) a vector of the same length as the number of samples specifying continuous numeric values per sample.
#' Default is the first sample covariates in covariates slot
#' @param warped logical indicating whether to show the aligned covariate (default: TRUE), 
#' only relevant if warping has been used to align multiple sample groups
#' @param scale logical indicating whether to scale factor values.
#' @param show_missing  (for 1-dim covariates) logical indicating whether to include samples for which \code{shape_by} or \code{color_by} is missing
#' @param color_by (for 1-dim covariates) specifies groups or values used to color the samples. This can be either:
#' (1) a character giving the name of a feature present in the training data.
#' (2) a character giving the name of a column present in the sample metadata.
#' (3) a vector of the same length as the number of samples specifying discrete groups or continuous numeric values.
#' @param shape_by  (for 1-dim covariates) specifies groups or values used to shape the samples. This can be either:
#' (1) a character giving the name of a feature present in the training data, 
#' (2) a character giving the name of a column present in the sample metadata.
#' (3) a vector of the same length as the number of samples specifying discrete groups.
#' @param color_name  (for 1-dim covariates) name for color legend.
#' @param shape_name  (for 1-dim covariates) name for shape legend.
#' @param dot_size  (for 1-dim covariates) numeric indicating dot size.
#' @param alpha  (for 1-dim covariates) numeric indicating dot transparency.
#' @param stroke  (for 1-dim covariates) numeric indicating the stroke size
#' @param legend  (for 1-dim covariates) logical indicating whether to add legend.
#' @param rotate_x (for spatial, 2-dim covariates) Rotate covariate on x-axis 
#' @param rotate_y (for spatial, 2-dim covariates) Rotate covariate on y-axis
#' @param return_data logical indicating whether to return the data frame to plot instead of plotting
#' @param show_variance  (for 1-dim covariates) logical indicating whether to show the marginal variance of inferred factor values 
#' (only relevant for 1-dimensional covariates)
#' @details To investigate the factors pattern along the covariates (such as time or a spatial coordinate) 
#' this function can be used to draw a scatterplot of the factor values against the values of each covariate
#' @return Returns a \code{ggplot2} object
#' @import ggplot2 dplyr
#' @importFrom stats complete.cases
#' @importFrom tidyr spread
#' @importFrom magrittr %>% set_colnames
#' @export
#' @examples 
#' # Using an existing trained model
#' file <- system.file("extdata", "MEFISTO_model.hdf5", package = "MOFA2")
#' model <- load_model(file)
#' plot_factors_vs_cov(model)

plot_factors_vs_cov <- function(object, factors = "all", covariates = NULL, warped = TRUE, show_missing = TRUE, scale = FALSE,
                                color_by = NULL, shape_by = NULL, color_name = NULL, shape_name = NULL,
                                dot_size = 1.5, alpha = 1, stroke = NULL, legend = TRUE,
                                rotate_x = FALSE, rotate_y = FALSE, return_data = FALSE, show_variance = FALSE) {
  
  # Sanity checks
  if (!is(object, "MOFA")) stop("'object' has to be an instance of MOFA")
  
  # Define covariates: default to all covariates stored in the object
  if (is.null(covariates)) {
    if (!.hasSlot(object, "covariates") || any(object@dimensions[["C"]] < 1, is.null(object@covariates)))  
      stop("No covariates found in object. Please specify one.")
    covariates <- covariates_names(object)
  }
  
  # Get factors (long format)
  factors <- .check_and_get_factors(object, factors)
  Z <- get_factors(object, factors=factors, as.data.frame=TRUE)
  
  # Remove samples with missing values
  Z <- Z[complete.cases(Z),]
  
  # Get covariates and merge them with the factor values. Both tables have a
  # "value" column, which becomes "value.covariate" and "value.factor".
  df <- get_covariates(object, covariates, as.data.frame = TRUE, warped = warped) %>%
    merge(Z, by="sample", suffixes = c(".covariate",".factor"))
  
  # Remember color_name and shape_name if not provided
  if (!is.null(color_by) && (length(color_by) == 1) && is.null(color_name))
    color_name <- color_by
  if (!is.null(shape_by) && (length(shape_by) == 1) && is.null(shape_name))
    shape_name <- shape_by
  
  # Set color and shape
  color_by <- .set_colorby(object, color_by)
  shape_by <- .set_shapeby(object, shape_by )
  
  # Merge factor values with color and shape information
  df <- df %>%
    merge(color_by, by="sample") %>%
    merge(shape_by, by="sample") %>%
    mutate(shape_by = as.character(shape_by))
  
  # Remove rows with missing color or shape. The condition must be
  # element-wise (`&`); the scalar `&&` errors on vectors in R >= 4.3.
  if (!show_missing) df <- filter(df, !is.na(color_by) & !is.na(shape_by))
  
  # Return data if requested instead of plotting
  if (return_data) return(df)
  
  # Set stroke
  if (is.null(stroke)) stroke <- .select_stroke(N=length(unique(df$sample)))
  
  # Select 1D or 2D plots
  if (length(covariates) == 1) {
    
    # Include marginal variance: Var = E[z^2] - E[z]^2
    if (show_variance) {
      if("E2" %in% names(object@expectations$Z)){
        ZZ <- object@expectations$Z$E2
        ZZ <- reshape2::melt(ZZ, na.rm=TRUE)
        colnames(ZZ) <- c("sample", "factor", "E2")
        df <- left_join(df, ZZ, by = c("sample", "factor"))
        # after the merge above the factor mean lives in "value.factor"
        df <- mutate(df, var = E2 - value.factor^2)
      } else {
        show_variance <- FALSE
        warning("No second moments saved in the trained model - variance can not be shown.")
      }
    }
    p <- .plot_factors_vs_cov_1d(df,
            color_name = color_name,
            shape_name = shape_name,
            scale = scale, 
            dot_size = dot_size, 
            alpha = alpha, 
            stroke = stroke,
            show_variance = show_variance,
            legend = legend,
            warped = warped
          ) 
  } else if (length(covariates) == 2) {
    p <- .plot_factors_vs_cov_2d(df,
           scale = scale, 
           rotate_x = rotate_x,
           rotate_y = rotate_y
          )
  } else {
    stop("too many covariates provided")
  }
  
  return(p)
}


# Internal helper of plot_factors_vs_cov() for a single covariate: scatterplot
# of factor values (y) against covariate values (x), one facet per factor.
# `df` is the long data frame assembled by plot_factors_vs_cov(), with columns
# covariate, value.covariate, factor, value.factor, color_by, shape_by and,
# when show_variance = TRUE, var (marginal variance of the factor value).
.plot_factors_vs_cov_1d <- function(df, color_name = "", shape_name = "", scale = FALSE, dot_size = 1.5, alpha = 1, stroke = 1, show_variance = FALSE, legend = TRUE, warped = TRUE) {
  
  # Sanity checks
  stopifnot(length(unique(df$covariate))==1)
  
  # x-axis label; mark covariate values that were not aligned across groups
  covariate_name <- unique(df$covariate)
  if(!warped){
    covariate_name <- paste(covariate_name, "(unaligned)")
  }
  
  # Scale each factor so that its largest absolute value is 1; the variance
  # is divided by the square of the same scaling constant.
  if (scale) {
    df <- df %>% 
      group_by(factor) %>%
      mutate(value_scaled = value.factor/max(abs(value.factor)))
    if(show_variance) df <- mutate(df, var = var/(max(abs(value.factor))^2))
    df <- df %>% 
      mutate(value.factor = value_scaled) %>%
      select(-value_scaled) %>%
      ungroup
  }
  
  # Generate plot
  p <- ggplot(df, aes(x=value.covariate, y=value.factor)) + 
    geom_point(aes(fill = .data$color_by, shape = .data$shape_by), colour="black", stroke = stroke, size=dot_size, alpha=alpha) +
    facet_wrap(~ factor) +
    theme_classic() +
    theme(
      axis.text = element_text(size = rel(0.9), color = "black"), 
      axis.title = element_text(size = rel(1.2), color = "black"), 
      axis.line = element_line(color = "black", linewidth = 0.5), 
      axis.ticks = element_line(color = "black", linewidth = 0.5)
    ) + xlab(covariate_name) + ylab("factor value")
  
  # Error bars of +/- 1.96 standard deviations around the factor value
  # (the factor value column is "value.factor", not "value")
  if (show_variance){
    p <- p + geom_errorbar(aes(ymin = value.factor - sqrt(var)*1.96, ymax = value.factor + sqrt(var)*1.96), col = "red", alpha = 0.7)
  }
  
  p <- .add_legend(p, df, legend, color_name, shape_name)
  
  return(p)
}

# Internal helper of plot_factors_vs_cov() for two covariates (e.g. spatial
# coordinates): one panel per factor, samples placed at their two covariate
# values and coloured by the factor value.
.plot_factors_vs_cov_2d <- function(df, scale = FALSE,
                                    rotate_x = FALSE, rotate_y= FALSE) {
  
  # Exactly two covariates are expected
  stopifnot(length(unique(df$covariate)) == 2)
  
  # Wide format: one column per covariate, appended after the existing columns
  wide_df <- tidyr::pivot_wider(df, names_from = "covariate", values_from = "value.covariate")
  n_col <- ncol(wide_df)
  axis_cols <- colnames(wide_df)[c(n_col - 1, n_col)]
  
  # Optionally rescale each factor so that its largest absolute value is 1
  if (scale) {
    wide_df <- ungroup(mutate(group_by(wide_df, factor),
                              value.factor = value.factor / max(abs(value.factor))))
  }
  
  # Colour column, kept for compatibility with .add_legend
  wide_df$color_by <- wide_df$value.factor
  
  axis_theme <- theme(
    axis.text = element_text(size = rel(0.9), color = "black"),
    axis.title = element_text(size = rel(1.0), color = "black"),
    axis.line = element_line(color = "black", linewidth = 0.5),
    axis.ticks = element_line(color = "black", linewidth = 0.5)
  )
  
  p <- ggplot(wide_df, aes(x = .data[[axis_cols[1]]],
                           y = .data[[axis_cols[2]]],
                           col = .data$color_by)) +
    geom_point() +
    scale_color_gradient2() +
    geom_point(col = "gray", alpha = 0.05) +
    facet_wrap(~ factor) +
    coord_fixed() +
    theme_bw() +
    axis_theme +
    guides(col = guide_colorbar(title = "Factor value"))
  
  # Optionally mirror the axes
  if (rotate_x) p <- p + scale_x_reverse()
  if (rotate_y) p <- p + scale_y_reverse()
  
  p
}


#' @title Interpolate factors in MEFISTO based on new covariate values
#' @name interpolate_factors
#' @description Function to interpolate factors in MEFISTO based on new covariate values.
#' @param object a \code{\link{MOFA}} object trained with MEFISTO options and a covariate
#' @param new_values a matrix containing the new covariate values to inter/extrapolate to. Should be
#'  in the same format as the covariates used for training (covariates in rows, new points in columns).
#'  A numeric vector is interpreted as values of a single covariate.
#' @return Returns the \code{\link{MOFA}} with interpolated factor values filled in the corresponding slot (interpolated_Z)
#' @details This function requires the functional MEFISTO framework to be used in training. 
#' Use \code{set_covariates} and specify mefisto_options when preparing the training using \code{prepare_mofa}. 
#' Currently, only the mean of the interpolation is provided from R.
#' @export
#' @examples
#' # Using an existing trained model
#' file <- system.file("extdata", "MEFISTO_model.hdf5", package = "MOFA2")
#' model <- load_model(file)
#' model <- interpolate_factors(model, new_values = seq(0,1.1,0.01))

interpolate_factors <- function(object, new_values) {
  
  # TODO check this function
  # message("We recommend doing interpolation from python where additionally uncertainties are provided for the interpolation.")
  
  # Sanity checks (class check first so that the slot accesses below are valid)
  if (!is(object, "MOFA")) stop("'object' has to be an instance of MOFA")
  if (is.null(object@covariates)) stop("'object' does not contain any covariates.")
  if (is.null(object@mefisto_options)) stop("'object' does not have MEFISTO training options.")
  if (is.null(object@expectations$Sigma)) stop("'object' does not have any expectations of Sigma.")
  if (!is.numeric(new_values)) stop("'new_values' should be numeric.")
  if (length(object@interpolated_Z) != 0) {
    warning("Object already contains interpolated factor values, overwriting it.")
  }
  
  # Covariates in rows, new points in columns; a plain vector is one covariate
  if (is.null(dim(new_values))) {
    new_values <- matrix(new_values, nrow = 1)
  }
  n_new <- ncol(new_values)
  n_groups <- object@dimensions$G
  group_names <- groups_names(object)
  
  # Covariate values of the training samples (covariates in rows, samples in columns)
  cov_cols <- covariates_names(object)
  if (isTRUE(object@mefisto_options$warping)) {
    cov_cols <- paste(cov_cols, "warped", sep = "_")
  }
  old_covariates <- t(samples_metadata(object)[, cov_cols, drop = FALSE])
  
  # new_values must provide one row per covariate used in training
  if (nrow(new_values) != nrow(old_covariates)) {
    stop("Number of covariates in new_values does not match covariates in model")
  }
  
  # Kernel parameters and per-factor covariance expectations
  ls <- get_lengthscales(object)
  Kgs <- get_group_kernel(object)
  s <- get_scales(object)
  Sigma <- object@expectations$Sigma$E
  
  # Unique covariate points (new and old) on which the kernel is evaluated
  all_covariates <- unique.matrix(cbind(new_values, old_covariates), MARGIN = 2)
  n_all <- ncol(all_covariates)
  
  # Describe every point by its group (first row) and covariate values
  # (remaining rows). `all_points` and `new_points` repeat their covariate
  # points once per group, group-major, matching the index order of the
  # Kronecker product Kg %x% Kc used below.
  old_points <- rbind(as.character(samples_metadata(object)$group), old_covariates)
  all_points <- rbind(rep(group_names, each = n_all),
                      all_covariates[, rep(seq_len(n_all), times = n_groups), drop = FALSE])
  new_points <- rbind(rep(group_names, each = n_new),
                      new_values[, rep(seq_len(n_new), times = n_groups), drop = FALSE])
  
  # Locate old and new points among all points via one string key per column
  col_keys <- function(m) apply(m, 2, paste, collapse = "\r")
  all_keys <- col_keys(all_points)
  oldidx <- match(col_keys(old_points), all_keys)
  newidx <- match(col_keys(new_points), all_keys)
  
  # Factor values of all training samples, stacked across groups
  Z <- Reduce(rbind, get_factors(object))
  
  # Squared Euclidean distances between covariate points (same for all factors)
  sq_dist <- as.matrix(dist(t(all_covariates)))^2
  
  # Interpolated mean per factor, K(new, old) %*% Sigma_k^{-1} %*% z_k, as a
  # (factors x (n_new * n_groups)) matrix. Factors with zero lengthscale or
  # zero scale have no smooth component and get NA.
  means <- do.call(rbind, lapply(seq_along(factors_names(object)), function(k) {
    if (ls[k] == 0 || s[k] == 0) {
      return(rep(NA_real_, n_new * n_groups))
    }
    Kc_new <- exp(-sq_dist / (2 * ls[k]^2))
    K_new_k <- s[k] * Kgs[[k]] %x% Kc_new
    as.numeric(K_new_k[newidx, oldidx, drop = FALSE] %*% solve(Sigma[k, , ], Z[, k]))
  }))
  
  # Split the interpolated means by group (always a factors x n_new matrix);
  # variances are only provided from python, so they are filled with NA
  res <- lapply(group_names, function(g) {
    list(mean = means[, new_points[1, ] == g, drop = FALSE],
         new_values = new_values,
         variance = matrix(NA, nrow = object@dimensions$K, ncol = n_new))
  })
  names(res) <- group_names
  
  object@interpolated_Z <- res
  
  return(object)
}


#' @title Plot covariate alignment across groups
#' @name plot_alignment
#' @description Function to plot the alignment learnt by MEFISTO for the 
#' covariate values between different groups
#' @param object a \code{\link{MOFA}} object using MEFISTO with warping
#' @return ggplot object showing the alignment
#' @details This function requires the functional MEFISTO framework to be used in training. 
#' Use \code{set_covariates} and specify mefisto_options when preparing the training using \code{prepare_mofa}. 
#' @export
#' 
plot_alignment <- function(object){
  # Sanity checks (isTRUE guards against a missing warping option)
  if (!is(object, "MOFA")) stop("'object' has to be an instance of MOFA")
  if (is.null(object@covariates)) stop("'object' does not contain any covariates.")
  if (is.null(object@mefisto_options)) stop("'object' does not have MEFISTO training options.")
  if (!isTRUE(object@mefisto_options$warping)) stop("No warping applied in this MOFA object")
  
  # Warped (aligned) and original values of covariate 1, one row per sample
  df_w <- get_covariates(object, 1, as.data.frame = TRUE, warped = TRUE)
  df_nw <- get_covariates(object, 1, as.data.frame = TRUE, warped = FALSE)
  
  df <- left_join(df_w, df_nw, by = c("sample"), suffix = c(".warped", ".unaligned"))
  df <- left_join(df, select(samples_metadata(object), group, sample), by = "sample")
  
  # y-axis label: the reference group if it is a known group name,
  # otherwise a generic label
  yname <- object@mefisto_options$warping_ref
  if (length(yname) != 1 || !yname %in% groups_names(object)) {
    yname <- "reference_value"
  }

  ggplot(df, aes(y = value.warped, x = value.unaligned)) +
    geom_point() + facet_wrap(~group) + theme_bw() + ylab(yname)
}


#' @title Plot variance explained by the smooth components of the model
#' @description This function plots the variance explained by the smooth components (Gaussian processes) underlying the factors in MEFISTO across different views and groups, as specified by the user.
#' @name plot_variance_explained_by_covariates
#' @param object a \code{\link{MOFA}} object
#' @param x character specifying the dimension for the x-axis ("view", "factor", or "group").
#' @param y character specifying the dimension for the y-axis ("view", "factor", or "group").
#' @param split_by character specifying the dimension to be faceted ("view", "factor", or "group").
#' @param factors character vector with a factor name(s), or numeric vector with the index(es) of the factor(s). Default is "all".
#' @param min_r2 minimum variance explained for the color scheme (default is 0).
#' @param max_r2 maximum variance explained for the color scheme.
#' @param compare_total plot corresponding variance explained in total in addition
#' @param legend logical indicating whether to add a legend to the plot  (default is TRUE).
#' @import ggplot2
#' @importFrom cowplot plot_grid
#' @importFrom reshape2 melt
#' @details Note that this function requires the use of MEFISTO. 
#' To activate the functional MEFISTO framework, specify mefisto_options when preparing the training using \code{prepare_mofa} 
#' @return A list of \code{\link{ggplot}} objects (if \code{compare_total} is TRUE) or a single \code{\link{ggplot}} object. 
#' Consider using cowplot::plot_grid(plotlist = ...) to combine the multiple plots that this function generates.
#' @export
#' @examples
#' # load_model
#' file <- system.file("extdata", "MEFISTO_model.hdf5", package = "MOFA2")
#' model <- load_model(file)
#' plot_variance_explained_by_covariates(model)
#' 
#' # compare to total variance explained
#' plist <- plot_variance_explained_by_covariates(model, compare_total = TRUE)
#' cowplot::plot_grid(plotlist = plist)

plot_variance_explained_by_covariates <- function(object, factors = "all",
                                              x = "view", y = "factor", split_by = NA,
                                              min_r2 = 0, max_r2 = NULL, compare_total = FALSE,
                                              legend = TRUE){
  
  # Sanity checks 
  if (length(unique(c(x, y, split_by))) != 3) { 
    stop(paste0("Please ensure x, y, and split_by arguments are different.\n",
                "  Possible values are `view`, `group`, and `factor`."))
  }
  
  # Automatically fill split_by in
  if (is.na(split_by)) split_by <- setdiff(c("view", "factor", "group"), c(x, y, split_by))
  
  views  <- .check_and_get_views(object, "all")
  groups <- .check_and_get_groups(object, "all")
  factors <- .check_and_get_factors(object, factors)
  
  # Collect relevant expectations
  W <- get_weights(object, views=views, factors=factors)
  
  # Interpolated (smooth) factor values at the training covariates of each
  # group, as a samples x factors matrix. Stored interpolations are reused
  # when they cover all training covariate values; otherwise they are
  # recomputed. drop = FALSE keeps a matrix when a single factor is requested.
  Z_interpol <- lapply(groups, function(g) {
    if(all(object@covariates_warped[[g]] %in% object@interpolated_Z[[g]]$new_values)){
      idx <- match(object@covariates_warped[[g]],  object@interpolated_Z[[g]]$new_values)
      mat <- t(get_interpolated_factors(object, only_mean = TRUE)[[g]]$mean)[idx, , drop = FALSE]
    } else {
      message("No interpolations found in object, recalculating them...")
      mm_tmp <- object
      mm_tmp@interpolated_Z <- list()
      mm_tmp <- interpolate_factors(mm_tmp, mm_tmp@covariates_warped[[g]])
      mat <- t(get_interpolated_factors(mm_tmp, only_mean = TRUE)[[g]]$mean)
    }
    # factors without a smooth component are NA and contribute nothing
    mat[is.na(mat)] <- 0
    colnames(mat) <- factors_names(object)
    rownames(mat) <- samples_names(object)[[g]]
    mat[, factors, drop = FALSE]
  })
  names(Z_interpol) <- groups
  
  # Data per view and group, transposed to samples x features
  Y <- lapply(get_data(object, add_intercept = FALSE)[views], function(view) view[groups])
  Y <- lapply(Y, function(x) lapply(x,t))
  
  # Variance explained (%) by the smooth part of each factor in each view:
  # 1 - RSS / SS, where RSS is the residual sum of squares of the rank-one
  # reconstruction (interpolated factor x weights) and SS the sum of squares
  # of the data.
  r2_GP <- lapply(groups, function(g) {
    r2 <- vapply(views, function(m) {
      ss_data <- sum(Y[[m]][[g]]^2, na.rm = TRUE)
      vapply(factors, function(k) {
        rss <- sum((as.matrix(Y[[m]][[g]]) - tcrossprod(Z_interpol[[g]][,k], W[[m]][,k]))^2, na.rm = TRUE)
        1 - rss/ss_data
      }, numeric(1))
    }, numeric(length(factors)))
    r2 <- matrix(r2, ncol = length(views), nrow = length(factors),
                 dimnames = list(factors, views))
    r2 * 100
  })
  names(r2_GP) <- groups
  
  r2_GP_df <- melt(
    lapply(r2_GP, function(x)
      melt(as.matrix(x), varnames = c("factor", "view"))
    ), id.vars=c("factor", "view", "value")
  )
  colnames(r2_GP_df)[ncol(r2_GP_df)] <- "group"
  r2_GP_df$factor <- factor(r2_GP_df$factor, levels = factors)
  r2_GP_df$group <- factor(r2_GP_df$group, levels = groups)
  r2_GP_df$view <- factor(r2_GP_df$view, levels = views)
  
  
  # Set R2 limits
  if (!is.null(min_r2)) r2_GP_df$value[r2_GP_df$value<min_r2] <- 0.001
  min_r2 <- 0
  if (!is.null(max_r2)) {
    r2_GP_df$value[r2_GP_df$value>max_r2] <- max_r2
  } else {
    # Upper limit also covers the total variance explained, so the smooth and
    # total plots (compare_total = TRUE) share one color scale
    r2_Z <- calculate_variance_explained(object)
    max_r2 <- max(max(Reduce(c,r2_Z$r2_per_factor)), max(r2_GP_df$value))
  }
  
  p1 <- ggplot(r2_GP_df, aes(x=.data[[x]], y=.data[[y]])) + 
    geom_tile(aes(fill=.data$value), color="black") +
    facet_wrap(as.formula(sprintf('~%s',split_by)), nrow=1) +
    labs(x="", y="", title="") +
    scale_fill_gradientn(colors=c("gray97","darkblue"), guide="colorbar", limits=c(min_r2,max_r2)) +
    guides(fill=guide_colorbar("Var. (%)")) +
    theme(
      axis.text.x = element_text(size=rel(1.0), color="black"),
      axis.text.y = element_text(size=rel(1.1), color="black"),
      axis.line = element_blank(),
      axis.ticks =  element_blank(),
      panel.background = element_blank(),
      strip.background = element_blank(),
      strip.text = element_text(size=rel(1.0))
    )
  
  if (!legend) p1 <- p1 + theme(legend.position = "none")
  
  # remove facet title
  if (length(unique(r2_GP_df[,split_by]))==1) p1 <- p1 + theme(strip.text = element_blank())
  
  if(!compare_total){
    return(p1)
  } else{
    list(p1  + ggtitle("smooth part"),
         plot_variance_explained(object, min_r2 = min_r2, max_r2 = max_r2,
                                 x= x, y=y, split_by=split_by, factors = factors) + ggtitle("total"))
  }
}
bioFAM/MOFA2 documentation built on June 12, 2024, 3:57 p.m.