# R/simulation_model.R

############################ simulation_model.R ################################
# Functions for the simulation of the models

#' Simulate a model for a specified set of perturbations
#'
#' Get the target simulations in a MIDAS design-like or list format.
#' @param model_description A MRAmodel object that describes the model to be fitted
#' @param targets A matrix of perturbations to simulate, as generated by \link{getCombinationMatrix} or "all" to reproduce the perturbations used to build the model
#' A perturbation matrix is filled with 0 and 1, each row is a perturbation set and column names are used to determine the name of the perturbation.
#' @param readouts List of nodes to simulate. If "all", all the nodes measured to fit the model will be used. Only nodes actually measured or inhibited for the model can be simulated.
#' @param inhibition_effect A single value, a list of values or NA. Values in ]0, -inf] to use for the inhibition, representing the log2-fold change in activity of the node (alternatively, a value between 0 and 1 representing the fraction of activity remaining after inhibition compared to basal). If NA, the values fitted for the inhibition will be used, or -1 if an inhibition is requested for a node that was not inhibited in the experiment.
#' @param with_offset Whether the simulation should include the offset (fitted simulation) or not (real activity prediction)
#' @return A list that represents a MIDAS measure-like format with fields 'conditions' the matrix of perturbations actually simulated, 'bestfit' the simulation, and 'variants' a list of simulations for the alternative parameter sets from profile likelihood. When the model carries data, the fields 'data', 'error' and 'unstim_data' hold the matching measurements (NA where no measurement corresponds to the simulated condition).
# @seealso \code{\link{getCombinationMatrix}}
#' @family simulation
#' @export
simulateModel <- function(model_description, targets="all", readouts = "all", inhibition_effect=NA, with_offset=TRUE) {
  design <- model_description$design
  nodes <- model_description$structure$names

  # Get the experimental design constraints on the prediction capacity.
  # The design stores node indices as used by the C++ code, hence the +1 to index 'nodes'.
  # Only combinations of effectively applied perturbations can be simulated
  if (length(design$inhib_nodes) > 0) {
    inhibables <- nodes[ 1 + unique(c(design$inhib_nodes)) ]
  } else {
    inhibables <- character()
  }
  if (length(design$stim_nodes) > 0) {
    stimulables <- nodes[ 1 + unique(c(design$stim_nodes)) ]
  } else {
    stimulables <- character()
  }
  measID <- 1 + unique(c(design$measured_nodes))
  measurables <- nodes[measID]

  # Get the names of the nodes in the network to stimulate and inhibit, and the matrix of the perturbation
  if (is.matrix(targets)) {
    # Already in matrix form, inhibitions are the columns whose name ends with "i"
    inhibitors <- gsub("i$", "", colnames(targets)[grep("i$", colnames(targets))] )
    stimulators <- colnames(targets)[grep("i$", colnames(targets), invert=TRUE)]
    target_matrix <- targets
  } else if (is.character(targets) && length(targets) == 1 && !is.na(targets) && targets == "all") {
    inhibitors <- nodes[design$inhib_nodes + 1]
    stimulators <- nodes[design$stim_nodes + 1]
    target_matrix <- cbind(design$inhibitor, design$stimuli)
    if (length(inhibitors) > 0) {
      colnames(target_matrix) <- c(paste0(inhibitors, "i"), stimulators)
    } else {
      colnames(target_matrix) <- stimulators
    }
  } else if (is.list(targets)) {
    # TODO: lists of perturbations giving the node names in vectors
    stop("Providing a list of target is not an implemented method")
  } else {
    # Fail explicitly instead of reaching the code below with an undefined perturbation matrix
    stop("'targets' must be a perturbation matrix or \"all\"")
  }
  # Collapse duplicated perturbation rows
  if (ncol(target_matrix) > 0) {
    target_matrix <- as.matrix(aggregate(target_matrix, by=as.data.frame(target_matrix), FUN=mean)[,1:ncol(target_matrix), drop=FALSE])
  }

  # Set the new experimental design that will be used for the simulation
  # Look for non existing perturbations and remove the corresponding lines and columns
  present_inh <- inhibitors %in% inhibables
  present_stim <- stimulators %in% stimulables
  if (!any(present_inh) && !any(present_stim)) {
    stop("No valid perturbations to simulate")
  }

  # Remove the rows of 'mat' where at least one of the columns 'cols' is applied
  drop_rows_using <- function(mat, cols) {
    used <- unique(unlist( lapply(cols, function(cc){ which(mat[,cc]==1) }) ))
    if (length(used) > 0) {
      mat <- mat[-used,, drop=FALSE]
    }
    mat
  }
  if (any(!present_inh)) {
    message(paste0(ifelse(sum(!present_inh)==1,"Node ","Nodes "), paste(inhibitors[!present_inh], collapse=" , "), " not inhibited in the network and won't be used\n"))
    target_matrix <- drop_rows_using(target_matrix, paste0(inhibitors[!present_inh], "i"))
  }
  if (any(!present_stim)) {
    message(paste0(ifelse(sum(!present_stim)==1,"Node ","Nodes "), paste(stimulators[!present_stim], collapse=" , "), " not stimulated in the network and won't be used"))
    target_matrix <- drop_rows_using(target_matrix, stimulators[!present_stim])
  }

  # Both perturbation matrices are extracted once all invalid rows are gone,
  # so they always have the same rows without any later realignment.
  inhib_nodes <- inhibitors[present_inh]
  stim_nodes <- stimulators[present_stim]
  # paste0() on an empty vector would return "i", hence the guard
  if (length(inhib_nodes) > 0) {
    i_inhib_nodes <- paste0(inhib_nodes, "i")
  } else {
    i_inhib_nodes <- character()
  }
  select_columns <- function(requested, valid_cols) {
    if (length(requested) == 0) {
      matrix(nrow=nrow(target_matrix), ncol=0)
    } else if (length(valid_cols) > 0) {
      target_matrix[, valid_cols, drop=FALSE]
    } else {
      target_matrix[, c(), drop=FALSE]
    }
  }
  inhibitions <- select_columns(inhibitors, i_inhib_nodes)
  stimulations <- select_columns(stimulators, stim_nodes)

  if (length(stim_nodes)==0 && length(inhib_nodes)==0) {
    stop("None of the perturbations provided correspond to nodes inhibited or stimulated in this network.")
  }
  # Generate the reduced target_matrix
  target_matrix <- cbind(stimulations, inhibitions)
  target_matrix <- aggregate(target_matrix, by=as.data.frame(target_matrix), FUN=mean)[,1:ncol(target_matrix), drop=FALSE]

  ## Set the nodes to be measured
  simulated_nodes <- character()
  if (all(readouts == "all")) {
    simulated_nodes <- measurables
  } else {
    for (node in readouts) {
      if (is.character(readouts)) {
        if (!(node %in% nodes)) {
          message(paste0("The node ", node, " is not in the network."))
        } else if (!(node %in% measurables)) {
          message(paste0("The node ", node, " cannot be measured with this model."))
        } else {
          simulated_nodes <- c(simulated_nodes, which(nodes == node))
        }
      } else if (is.numeric(readouts)) { # Consider the R style numeration
        if (!(node %in% seq_along(nodes))) {
          message(paste0("There are only ", length(nodes), " node in the network."))
        } else if (!(node %in% measID)) {
          message(paste0("The node ", node, " (", nodes[node], ") cannot be measured with this model."))
        } else {
          simulated_nodes <- c(simulated_nodes, node)
        }
      }
    }
  }
  # Readouts collected as indices are stored as text at this point, convert them back to node names
  simulated_nodes <- unique(simulated_nodes)
  node_index <- suppressWarnings(!is.na(as.numeric(simulated_nodes)))
  simulated_nodes[node_index] <- model_description$structure$names[as.numeric(simulated_nodes[node_index])]
  if (length(simulated_nodes) == 0) {
    stop("None of the simulations required correspond to nodes measured in the network.")
  }
  simulated_index <- sapply(simulated_nodes, function(nn) { which(model_description$structure$names==nn)-1 } ) # C++ index of the simulated_nodes
  simulated_cols <- sapply(simulated_index, function(idx) { which(design$measured_nodes==idx) } ) # Index in result matrix column, ordered like simulated_nodes
  new_design <- getExperimentalDesign(model_description$structure, stim_nodes, inhib_nodes, simulated_nodes, stimulations, inhibitions, model_description$basal)

  # Set up the model and the data for the simulation
  model <- new(STASNet:::Model)
  model$setModel( new_design, model_description$structure )
  new_data <- new(STASNet:::Data)
  n_conditions <- nrow(target_matrix)
  new_data$set_unstim_data(matrix( rep(model_description$data$unstim_data[1,simulated_cols], n_conditions), byrow=TRUE, nrow=n_conditions ))
  new_data$set_scale( matrix( rep(model_description$data$scale[1, simulated_cols], n_conditions), byrow=TRUE, nrow=n_conditions ) )

  # Compute the predictions
  prediction <- list()
  # Reset the target matrix to match the layout in new_design
  target_matrix <- cbind(new_design$stimuli, new_design$inhibitor)
  colnames(target_matrix) <- c( stim_nodes, i_inhib_nodes )
  prediction$conditions <- target_matrix

  ## Use the optimal fit
  old_inhib_nodes <- model_description$structure$names[1+design$inhib_nodes]
  if (is.numeric(inhibition_effect)) {
    use_fitted <- FALSE
    inhib_values <- inhibition_effect
  } else {
    use_fitted <- TRUE
    inhib_values <- -1
  }
  # Run the simulation for one parameter set and label the columns
  simulate_with <- function(params) {
    if (with_offset) {
      simulation <- model$simulateWithOffset(new_data, params)$prediction
    } else {
      simulation <- model$simulate(new_data, params)$prediction
    }
    colnames(simulation) <- simulated_nodes
    simulation
  }

  new_params <- getParametersForNewDesign(model, model_description$model, model_description$parameters, old_inhib_nodes, inhib_nodes, inhib_values, use_fitted)
  prediction$bestfit <- simulate_with(new_params)

  ## Parameters sets provided by the profile likelihood
  params_sets <- list()
  if (length(model_description$param_range) == length(model_description$parameters)) {
    for (i in seq_along(model_description$param_range)) {
      if (!is.na(model_description$param_range[[i]]$low_set[1])) {
        params_sets[[length(params_sets)+1]] <- getParametersForNewDesign(model, model_description$model, model_description$param_range[[i]]$low_set, old_inhib_nodes, inhib_nodes, inhib_values, use_fitted)
      }
      if (!is.na(model_description$param_range[[i]]$high_set[1])) {
        params_sets[[length(params_sets)+1]] <- getParametersForNewDesign(model, model_description$model, model_description$param_range[[i]]$high_set, old_inhib_nodes, inhib_nodes, inhib_values, use_fitted)
      }
    }
  }
  ### Predictions for the extra parameter sets
  prediction$variants <- lapply(params_sets, simulate_with)

  prediction$data <- list()
  prediction$error <- list()
  # Extract the data that correspond to the simulation
  # The code assumes that all perturbed nodes are perturbed in the data which should be guaranteed by the first part of the function
  if (exists("data", model_description)) {
    sim_design <- target_matrix
    data_design <- cbind(design$stimuli, design$inhibitor)
    if (length(design$inhib_nodes) > 0) {
      colnames(data_design) <- c( model_description$structure$names[1+design$stim_nodes], paste0(model_description$structure$names[1+design$inhib_nodes], "i") )
    } else {
      colnames(data_design) <- model_description$structure$names[1+design$stim_nodes]
    }
    # Extract from the data the lines that correspond to the perturbations of the simulation, where no other perturbation than those is applied
    in_simulation <- colnames(data_design) %in% colnames(sim_design)
    valid_lines <- which(apply(data_design, 1, function(drow){ all(drow[!in_simulation]==0) }))
    # Columns present in both designs, addressed by name so the column order of each design does not matter
    common <- colnames(data_design)[in_simulation]
    control_line <- which(apply(sim_design, 1, function(srow){ all(srow==0) }))

    match_data <- numeric()
    match_sim <- numeric()
    for (sr in seq_len(nrow(sim_design))) {
      corresponding <- vapply(valid_lines, function(dr){ all(data_design[dr, common] == sim_design[sr,common]) }, logical(1))
      if (any(corresponding)) {
        if (length(which(corresponding))>1) {
          stop("More than one 'corresponding' line found between simulated and original design")
        }
        match_data <- c(match_data, valid_lines[which(corresponding)])
        match_sim <- c(match_sim, sr)
      }
    }
    # Default to NA if the data are not present
    # Data field must exist, check for non emptiness for 'createSimulation' where data exist but are empty
    prediction$data <- matrix( NA, ncol=ncol(prediction$bestfit), nrow=nrow(prediction$bestfit), dimnames=list(NULL, colnames(prediction$bestfit)) )
    prediction$error <- prediction$data
    if (exists("stim_data", model_description$data) && nrow(model_description$data$stim_data) > 0) {
      prediction$data[match_sim,] <- model_description$data$stim_data[match_data, simulated_cols]
      if (length(control_line)>0) {
        prediction$data[control_line,] <- new_data$unstim_data[1,]
      }
    }
    if (exists("error", model_description$data) && nrow(model_description$data$error) > 0) {
      prediction$error[match_sim,] <- model_description$data$error[match_data, simulated_cols]
      if (length(control_line)>0) {
        prediction$error[control_line,] <- rep(0, ncol(prediction$error))
      }
    }
    prediction$unstim_data <- matrix(rep(new_data$unstim_data[1,], nrow(prediction$bestfit)), nrow=nrow(prediction$bestfit), byrow=TRUE)
  }

  rm(model) # Free the memory
  return(prediction)
}

# Translate the parameters fitted with the old model into a parameter set usable by the new model
# For every inhibited node of the new design, the inhibition strength is either
# the value fitted in the old model or derived from 'inhibition':
#   - one value per node if 'inhibition' is as long as 'inhib_nodes', otherwise its first value
#   - values above 1 are negated, values in ]0, 1] are converted with log2, others are used as is
#' @family simulation
getParametersForNewDesign <- function(new_model, old_model, old_parameters, old_inhib, inhib_nodes, inhibition=-1, use_fitted_inhib=T) {
  # Local response (adjacency) matrix and inhibition values of the old fit
  response <- old_model$getLocalResponseFromParameter(old_parameters)

  inhibition_for <- function(inhibitor) {
    if (use_fitted_inhib && inhibitor %in% old_inhib) {
      return(response$inhibitors[which(old_inhib == inhibitor)])
    }
    # Possibility to define one inhibition for all or personalised inhibitions
    per_node <- length(inhibition) == length(inhib_nodes)
    strength <- if (per_node) inhibition[which(inhib_nodes == inhibitor)] else inhibition[1]
    if (strength > 1) {
      strength <- -strength
    } else if (strength > 0) {
      strength <- log2(strength)
    }
    strength
  }

  inhib_values <- c(numeric(), unlist(lapply(inhib_nodes, inhibition_for)))
  return(new_model$getParameterFromLocalResponse(response$local_response, inhib_values))
}

#' Create a perturbation matrix
#'
#' Create the perturbation matrix for a set of perturbations, building all n-combinations of stimulators with all m-combinations of inhibitors. Add the cases with only stimulations and only inhibitions.
#' @param perturbations A vector with the name of the perturbation, either NODE for a stimulation a node, or NODEi for its inhibition
#' @param inhib_combo Number of inhibitions to use simultaneously in each perturbation, at least 1. Only used if inhibitions are provided.
#' @param stim_combo Number of stimulations to use simultaneously in each perturbation, at least 1. Only used if stimulations are provided.
#' @param byStim Whether the perturbations should be ordered according to the stimulations or the inhibitions
#' @return A perturbation matrix filled with 0 and 1, with the names of the perturbations as column names and no row names. The first row is the unperturbed condition.
#' @export
#' @family simulation
#' @author Mathurin Dorel \email{dorel@@horus.ens.fr}
getCombinationMatrix <- function (perturbations, inhib_combo = 2, stim_combo = 1, byStim=TRUE) {
  if (!is.numeric(inhib_combo)) { stop("'inhib_combo' must be numeric") }
  if (!is.numeric(stim_combo)) { stop("'stim_combo' must be numeric") }
  # Inhibitions are the perturbations whose name ends with "i"
  is_inhibition <- grepl("i$", perturbations)
  stimulators <- perturbations[!is_inhibition]
  inhibitors <- perturbations[is_inhibition]
  if (length(stimulators) > 0 && stim_combo > length(stimulators)) {
    stop ("Not enough stimulations to build the combinations")
  }
  if (length(inhibitors) > 0 && inhib_combo > length(inhibitors)) {
    stop ("Not enough inhibitions to build the combinations")
  }
  if (length(perturbations) == 0) {
    stop("Inhibitions or stimulations must be provided to build the combination matrix")
  }

  # Build the 0/1 matrix with one row per combination of 'combo_size' labels,
  # preceded by a row of 0 so that the other kind of perturbation also appears alone
  indicator_matrix <- function(labels, combo_size, arg_name) {
    if (length(combo_size) != 1 || is.na(combo_size) || combo_size < 1) {
      stop(paste0("'", arg_name, "' must be a single number greater than or equal to 1"))
    }
    # Fractional sizes are rounded up; rows are in lexicographic order of the label indices
    combos <- t(utils::combn(seq_along(labels), ceiling(combo_size)))
    indicator <- matrix(0, nrow=nrow(combos) + 1, ncol=length(labels))
    colnames(indicator) <- labels
    # 'combos' is read column by column, so the row index is recycled once per column
    indicator[cbind(rep(seq_len(nrow(combos)), times=ncol(combos)) + 1, as.vector(combos))] <- 1
    indicator
  }

  inhib_matrix <- NULL
  if (length(inhibitors) > 0) {
    inhib_matrix <- indicator_matrix(inhibitors, inhib_combo, "inhib_combo")
  }
  stim_matrix <- NULL
  if (length(stimulators) > 0) {
    stim_matrix <- indicator_matrix(stimulators, stim_combo, "stim_combo")
  }
  # Only one kind of perturbation: its matrix is the result
  if (is.null(stim_matrix)) {
    return(inhib_matrix)
  }
  if (is.null(inhib_matrix)) {
    return(stim_matrix)
  }

  # Merge the two matrices to get all the combinations, the leading kind varies slowest
  n_stim <- nrow(stim_matrix)
  n_inhib <- nrow(inhib_matrix)
  if (byStim) {
    perturbation_matrix <- cbind(
      stim_matrix[rep(seq_len(n_stim), each=n_inhib), , drop=FALSE],
      inhib_matrix[rep(seq_len(n_inhib), times=n_stim), , drop=FALSE]
    )
  } else {
    perturbation_matrix <- cbind(
      inhib_matrix[rep(seq_len(n_inhib), each=n_stim), , drop=FALSE],
      stim_matrix[rep(seq_len(n_stim), times=n_inhib), , drop=FALSE]
    )
  }
  rownames(perturbation_matrix) <- NULL

  return(perturbation_matrix)
}

# Recursively enumerate the k-combinations of a set, one combination per row
# @param symbols Elements still available to complete the combinations
# @param remaining_steps Number of elements that still have to be added to each combination
# @param to_extend Matrix of the partial combinations built so far
# Returns 'to_extend' once no step remains, and nothing when too few symbols are left
build_combo <- function (symbols, remaining_steps, to_extend) {
  if (remaining_steps <= 0) {
    return(to_extend)
  }
  if (length(symbols) < remaining_steps) {
    return(c())
  }
  # Pick each symbol in turn, then complete the combination using only the symbols located after it
  branches <- lapply(seq(length(symbols)), function(index) {
    build_combo(symbols[ -(1:index) ], remaining_steps-1, cbind(to_extend, symbols[index]))
  })
  return(do.call(rbind, branches))
}

#' Plot the predictions by the model
#'
#' Simulate the model with \link{simulateModel} and plot the result with \link{plotSimulation}.
#' @param model_description A MRAmodel object that describes the model to be fitted
#' @param targets A matrix of perturbations to simulate, as generated by \link{getCombinationMatrix} or "all" to reproduce the perturbations used to build the model
#' A perturbation matrix is filled with 0 and 1, each row is a perturbation set and column names are used to determine the name of the perturbation.
#' @param readouts List of nodes to simulate. If "all", all the nodes measured to fit the model will be used. Only nodes actually measured or inhibited for the model can be simulated.
#' @param inhibition_effect A single value, a list of values or NA. Values in ]0, -inf] to use for the inhibition, representing the log2-fold change in activity of the node (alternatively, a value between 0 and 1 representing the fraction of activity remaining after inhibition compared to basal). If NA, the values fitted for the inhibition will be used, or -1 if an inhibition is requested for a node that was not inhibited in the experiment.
#' @param log_axis Boolean, whether the ordinate axis should be in log scale
#' @param with_data Plot the data of the model next to the prediction
#' @param compare A list of MRAmodel to compare the predictions, targets and readouts must be valid for MRAmodels
#' @return Invisibly, the simulation as a list that represents a MIDAS measure-like format with fields 'conditions' the matrix of perturbations provided as 'targets', 'bestfit' the simulation, and 'variants' a list of simulations for the alternative parameter sets from profile likelihood
#' @export
#' @author Mathurin Dorel \email{mathurin.dorel@@charite.de}
plotModelSimulation <- function(model_description, targets="all", readouts = "all", inhibition_effect=NA, log_axis=TRUE, with_data=FALSE, compare=list()) {
    # Simulate the main model, then every model to compare it with, under the same settings
    main_prediction <- simulateModel(model_description, targets, readouts, inhibition_effect)
    compared_predictions <- lapply(compare, simulateModel, targets, readouts, inhibition_effect)
    plotted <- plotSimulation(main_prediction, log_axis, compare=compared_predictions, with_data=with_data)
    invisible(plotted)
}

#' Plot predictions generated by simulateModel
#'
#' One plot per measured node.
#' Plot with error bars if available, and give absolute value
#' or log-fold change
#' @param prediction A list of the model predictions as produced by \link{simulateModel}
#' @param log_axis Boolean, whether the ordinate axis should be in log scale
#' @param with_data Display 2 bars per condition, one for the simulation and one for the data
#' @param data_color Color of the bars corresponding to the data
#' @param sim_colors Vector of colors for the bars corresponding to the simulations. The first color is used for the main prediction and the next for the predictions in 'compare', recycled as necessary for all predictions.
#' @param compare A list of predictions to compare to. A prediction is a list with fields 'bestfit', 'conditions', 'error', 'data', 'unstim_data' and 'variants' as generated by simulateModel. The names of the list are used in the legend.
#' @param strict When 'compare' is not empty, whether the other prediction conformity must be strictly evaluated (presence and equality of the field 'condition', equal name of columns in field 'bestfit') or not (the 'bestfit' dimensions are the only thing that is checked)
#' @return Invisibly, the matrix of the results of the simulation
#' @export
#' @seealso getCombinationMatrix, simulateModel
#' @author Mathurin Dorel \email{mathurin.dorel@@charite.de}
#' @family simulation
# TODO , plotsPerFrame = 4
# @param maxPlotsPerFrame Maximum number of perturbation per frame
# TODO add error multiplier
plotSimulation <- function(prediction, log_axis=FALSE, with_data=TRUE, data_color=cbbPalette[1], sim_colors=cbbPalette[-1], compare=list(), strict=TRUE) {
  # Validate the input before any field of 'prediction' is read
  if (!is.list(prediction) || !is.matrix(prediction$bestfit) ) {
    stop("Invalid type for argument 'prediction', see 'simulateModel' return value")
  }

  # The data can only be displayed if the prediction carries some
  with_data <- with_data && length(prediction$data) > 0
  if (with_data) {
    colors <- c(data_color, sim_colors[1])
    legend_labels <- c("data", "simulation")
  } else {
    colors <- sim_colors[1]
    legend_labels <- c("simulation")
  }

  # Collect the predictions to compare with that are compatible with the main one
  compare_data <- list()
  for (pidx in seq_along(compare)) {
    pred <- compare[[pidx]]
    compatible <- is.list(pred) && is.matrix(pred$bestfit) && identical(dim(pred$bestfit), dim(prediction$bestfit))
    if (compatible && strict) {
      compatible <- identical(dim(pred$conditions), dim(prediction$conditions)) &&
        isTRUE(all(pred$conditions == prediction$conditions)) &&
        identical(colnames(pred$bestfit), colnames(prediction$bestfit))
    }
    if (compatible) {
      colnames(pred$bestfit) <- colnames(prediction$bestfit)
      compare_data[[length(compare_data)+1]] <- pred$bestfit
      # Legend entry from the name given in 'compare', with a fallback for unnamed entries
      label <- names(compare)[pidx]
      if (is.null(label) || is.na(label) || label == "") {
        label <- paste("comparison", pidx)
      }
      legend_labels <- c(legend_labels, label)
      # Colors 2, 3, ... of 'sim_colors', wrapping back to the first one when exhausted
      colors <- c(colors, sim_colors[(length(compare_data) %% length(sim_colors)) + 1]) # TODO define a gradient or something
    } else {
      warning(paste0("The prediction ", pidx, " was not used because it is invalid"))
    }
  }

  ratio <- 2/3 # Display height ratio between the plot and its annotation
  layout(matrix(1:2, nrow=2, byrow=TRUE), heights=c(ratio, 1-ratio))
  old_mar <- par()$mar
  # plot the data
  for (node in seq_len(ncol(prediction$bestfit))) {
    # Collects the positions of the bars
    par(mar = c(1, 6, 4, 4))
    if (length(compare_data) > 0) {
      # One row per compared prediction, also when there is a single condition
      compare_row <- do.call(rbind, lapply(compare_data, function(cd){ cd[,node] }))
    } else {
      compare_row <- c()
    }
    if (with_data) {
      to_plot <- rbind(prediction$data[,node], prediction$bestfit[,node], compare_row)
      bars <- barplot(to_plot, plot=FALSE, beside=TRUE)
      sim_bars <- bars[2,]
      data_bars <- bars[1,]
      bars <- colMeans(bars)
    } else if (length(compare_data) > 0) {
      to_plot <- rbind(prediction$bestfit[,node], compare_row)
      bars <- barplot(to_plot, plot=FALSE, beside=TRUE)
      sim_bars <- bars[1,]
      bars <- colMeans(bars)
    } else {
      to_plot <- cbind(prediction$bestfit[,node])
      bars <- barplot(to_plot, plot=FALSE, beside=TRUE)
      sim_bars <- bars
    }
    if (log_axis) { to_plot <- log(to_plot, 10) }
    limits <- c(if (log_axis) 1 else 0, (if (log_axis) 1.1 else 1.5) * max(c(to_plot), na.rm=TRUE)) # Expect values > 1
    entity <- colnames(prediction$bestfit)[node]
    if (length(prediction$variants) > 0) {
      n_conditions <- nrow(prediction$bestfit)
      low_var <- numeric(n_conditions)
      high_var <- numeric(n_conditions)
      # Collect the extreme simulated values for each condition, and the global extremes to be sure everything gets included in the plot
      for (perturbation in seq_len(n_conditions)) {
        variants <- c(prediction$bestfit[perturbation, node], unlist(lapply(prediction$variants, function(variant){ variant[perturbation, node] })))
        variants <- variants[!is.na(variants)]
        lowest <- if (length(variants) > 0) min(variants) else NA
        highest <- if (length(variants) > 0) max(variants) else NA
        # NOTE(review): with log_axis this compares a log10 limit to a raw value, kept as is - confirm the intended scale
        limits[1] <- min(limits[1], lowest, na.rm=TRUE)
        # If the inaccuracy yields negative activity, we correct if log scale is used
        if (log_axis && !is.na(lowest) && lowest <= 0) { lowest <- 0.000001 }
        low_var[perturbation] <- lowest
        high_var[perturbation] <- highest
      }
      # Plot the bars with the errors
      if (log_axis){
        barplot(to_plot, ylim=limits, ylab=paste(entity, "log activity (AU)"), col=colors, las=1, main=entity, beside=TRUE)
      }else{
        barplot(to_plot, ylim=limits, ylab=paste(entity, "activity (AU)"), col=colors, main=entity, beside=TRUE)
      }

      text_pos <- limits[2] - 0.1 * limits[2]
      if (log_axis) {
        low_var <- log(low_var, 10)
        high_var <- log(high_var, 10)
      }
      # Write the value if the bar goes outside the plotting frame
      segments( sim_bars, low_var, sim_bars, ifelse(high_var>limits[2], text_pos, high_var) )
      text( sim_bars, text_pos, sapply(high_var, function(X){ ifelse(X>limits[2],ifelse(X<100000,round(X),signif(X,1)), "") }), pos=2, srt=90,offset=0.2 )
      space <- abs(sim_bars[2] - sim_bars[1])/(3*(2+length(compare)+ifelse(with_data, 1, 0)))
      segments(sim_bars - space, low_var, sim_bars + space, low_var)
      in_lim <- high_var<=limits[2]
      segments(sim_bars[in_lim] - space, high_var[in_lim], sim_bars[in_lim] + space, high_var[in_lim])
    } else {
      if (log_axis){
        barplot(to_plot, ylab=paste(entity, "log activity (AU)"), col=colors, las=1, main=entity, beside=TRUE)
      }else{
        barplot(to_plot, ylab=paste(entity, "activity (AU)"), col=colors, main=entity, beside=TRUE)
      }
      low_var <- 0
      # NOTE(review): this overwrites the lower limit with a value derived from the maximum, kept as is - confirm it was not meant for limits[2]
      limits[1] <- 1.2 * max(prediction$bestfit[,node], na.rm=TRUE)
    }

    # Write the conditions used, below the plot in the coordinates of the bar plot
    par(mar = c(0, 6, 0, 4), xpd=NA)
    pert_name_x <- ifelse(length(bars)>1, bars[1]-(bars[2]-bars[1])/2, 0)
    for (pert in seq_len(ncol(prediction$conditions))) {
      legend_line <- rep("-", nrow(prediction$conditions))
      legend_line[prediction$conditions[, pert] == 1] <- "+"
      y_coord <- min(limits[1], low_var, na.rm=TRUE) - (pert+1) * limits[2] * 0.9 * (1-ratio) / (ncol(prediction$conditions)+1)
      text(bars, y_coord, legend_line)
      text(pert_name_x, y_coord, colnames(prediction$conditions)[pert], pos=2)
    }
    # Empty plot to move to the next frame of the layout
    eplot( xlim=c(0, 1), ylim=c(0, 1) )
  }
  par(mar=old_mar, xpd=TRUE)
  layout(1)

  # plot the legend in a separate plot
  eplot(c(0, 1),c(0, 1))
  legend("center", legend=legend_labels, col=colors, lwd=2, horiz=FALSE, bty="n")

  # Invisibly returns the prediction
  return(invisible(prediction))
}

#' Simulate the model for the experimental design used for the fitting
#' @param mra_model The MRAmodel to simulate
#' @param with_offset Whether the simulation is computed with 'simulateWithOffset' (TRUE) or with 'simulate' (FALSE)
#' @return A matrix containing the simulation with the names of the measured nodes as column names
getSimulation <- function(mra_model, with_offset=TRUE) {
  # Both simulation methods are fed with the data and the fitted parameters of the model
  simulation <- if (with_offset) {
    mra_model$model$simulateWithOffset(mra_model$data, mra_model$parameters)
  } else {
    mra_model$model$simulate(mra_model$data, mra_model$parameters)
  }
  prediction <- simulation$prediction
  colnames(prediction) <- getMeasuredNodesNames(mra_model)
  return(prediction)
}

# Open an empty plotting region (no points, axes or labels), useful to write only text
# Extra arguments are forwarded to plot()
eplot <- function(xlim, ylim, ...) {
  plot(
    1,
    type="n",
    xlim=xlim,
    ylim=ylim,
    axes=FALSE,
    xlab=NA,
    ylab=NA,
    ...
  )
}
# MathurinD/STASNet documentation built on May 28, 2019, 1:50 p.m.