# R/plot.R

# Alternative 250-colour scale (hex codes in #RRGGBBAA form). rev() flips the
# literal so the vector runs from yellow ("#FDE725FF") to dark purple
# ("#440154FF"); the values appear to be the viridis palette -- NOTE(review):
# confirm source. It is not referenced by any function in this part of the
# file; verify it is unused elsewhere before removing it. Some adjacent
# entries are duplicated (e.g. "#481B6DFF", "#26828EFF").
vcols1 <- rev(c("#440154FF", "#440256FF", "#450457FF", "#450559FF", "#46075AFF",
           "#46085CFF", "#460A5DFF", "#460B5EFF", "#470D60FF", "#470E61FF",
           "#471063FF", "#471164FF", "#471366FF", "#481567FF", "#481668FF",
           "#481769FF", "#48196BFF", "#481B6DFF", "#481B6DFF", "#481C6EFF",
           "#481E70FF", "#482070FF", "#482072FF", "#482274FF", "#482475FF",
           "#482576FF", "#482677FF", "#482778FF", "#482979FF", "#472A7AFF",
           "#472B7AFF", "#472D7BFF", "#472E7CFF", "#472F7DFF", "#46307EFF",
           "#46327EFF", "#46337FFF", "#463480FF", "#453581FF", "#453781FF",
           "#453882FF", "#443983FF", "#443A83FF", "#443B84FF", "#433D84FF",
           "#433E85FF", "#423F85FF", "#424086FF", "#424186FF", "#414287FF",
           "#414487FF", "#404588FF", "#404688FF", "#3F4788FF", "#3F4889FF",
           "#3E4989FF", "#3E4B89FF", "#3E4C8AFF", "#3D4D8AFF", "#3D4E8AFF",
           "#3C4F8AFF", "#3C508BFF", "#3B518BFF", "#3A538BFF", "#3A548CFF",
           "#39558CFF", "#39558CFF", "#38578CFF", "#38598CFF", "#375A8CFF",
           "#375B8DFF", "#365C8DFF", "#365D8DFF", "#355E8DFF", "#355F8DFF",
           "#34608DFF", "#34618DFF", "#33628DFF", "#33638DFF", "#32648EFF",
           "#32658EFF", "#31668EFF", "#31678EFF", "#31688EFF", "#30698EFF",
           "#306A8EFF", "#2F6B8EFF", "#2F6C8EFF", "#2E6D8EFF", "#2E6E8EFF",
           "#2E6F8EFF", "#2D708EFF", "#2D718EFF", "#2C718EFF", "#2C728EFF",
           "#2C738EFF", "#2B748EFF", "#2B758EFF", "#2A768EFF", "#2A778EFF",
           "#2A788EFF", "#29798EFF", "#297A8EFF", "#287B8EFF", "#287D8EFF",
           "#287E8EFF", "#277F8EFF", "#27808EFF", "#26818EFF", "#26828EFF",
           "#26828EFF", "#25838EFF", "#25848EFF", "#25858EFF", "#24868EFF",
           "#24878EFF", "#23888EFF", "#23898EFF", "#238A8DFF", "#228B8DFF",
           "#228C8DFF", "#228D8DFF", "#218E8DFF", "#218F8DFF", "#21908DFF",
           "#21918CFF", "#20928CFF", "#20928CFF", "#20938CFF", "#1F948CFF",
           "#1F958BFF", "#1F968BFF", "#1F978BFF", "#1F988BFF", "#1F998AFF",
           "#1F9A8AFF", "#1E9B8AFF", "#1E9C89FF", "#1E9D89FF", "#1F9E89FF",
           "#1F9F88FF", "#1FA088FF", "#1FA188FF", "#1FA187FF", "#20A287FF",
           "#20A386FF", "#21A586FF", "#21A685FF", "#22A785FF", "#22A885FF",
           "#23A983FF", "#24AA83FF", "#25AB82FF", "#25AC82FF", "#26AD81FF",
           "#27AD81FF", "#28AE80FF", "#29AF7FFF", "#2AB07FFF", "#2CB17EFF",
           "#2DB27DFF", "#2EB37CFF", "#2FB47CFF", "#31B57BFF", "#32B67AFF",
           "#34B679FF", "#35B779FF", "#37B878FF", "#38B977FF", "#3ABA76FF",
           "#3BBB75FF", "#3DBC74FF", "#3FBC73FF", "#40BD72FF", "#42BE71FF",
           "#44BF70FF", "#46C06FFF", "#49C16EFF", "#4BC16DFF", "#4DC26CFF",
           "#4FC36BFF", "#51C46AFF", "#53C569FF", "#55C568FF", "#57C666FF",
           "#59C864FF", "#5BC864FF", "#5DC863FF", "#5FCA61FF", "#62CB5FFF",
           "#64CB5FFF", "#66CB5DFF", "#68CD5BFF", "#6BCD5AFF", "#6DCE59FF",
           "#6FCF57FF", "#72D056FF", "#75D054FF", "#76D153FF", "#79D151FF",
           "#7CD250FF", "#7FD34EFF", "#81D34DFF", "#84D44BFF", "#86D549FF",
           "#89D548FF", "#8BD646FF", "#8ED645FF", "#90D743FF", "#93D741FF",
           "#95D840FF", "#98D83EFF", "#9BD93CFF", "#9DD93BFF", "#A0DA39FF",
           "#A2DA37FF", "#A6DB36FF", "#A8DB34FF", "#ABDC31FF", "#AEDC30FF",
           "#B1DD2EFF", "#B3DD2CFF", "#B6DE2AFF", "#B9DE29FF", "#BBDE27FF",
           "#BEDF26FF", "#C1DF24FF", "#C3DF22FF", "#C7E020FF", "#C9E020FF",
           "#CCE11EFF", "#CFE11CFF", "#D1E21BFF", "#D4E21AFF", "#D7E219FF",
           "#D9E319FF", "#DCE318FF", "#DEE318FF", "#E1E418FF", "#E4E419FF",
           "#E7E419FF", "#E9E51AFF", "#ECE51BFF", "#EFE51CFF", "#F1E51DFF",
           "#F4E61EFF", "#F6E620FF", "#F8E621FF", "#FBE723FF", "#FDE725FF"))

# Default colour scale for the plotCVE() heat map (the default value of its
# `col.regions` argument). 250 hex colours (#RRGGBBAA) running from pale
# yellow ("#FCFDBFFF") to near-black ("#000004FF"); the values appear to be
# the magma palette in reverse order -- NOTE(review): confirm source.
vcols <- c("#FCFDBFFF", "#FCFBBDFF", "#FCF9BBFF", "#FCF7B9FF", "#FCF6B8FF",
           "#FCF4B6FF", "#FCF2B4FF", "#FCF0B2FF", "#FCEEB0FF", "#FCECAEFF",
           "#FDEBABFF", "#FDE8AAFF", "#FDE6A9FF", "#FDE4A6FF", "#FDE3A4FF",
           "#FDE1A2FF", "#FDDFA1FF", "#FDDD9FFF", "#FDDB9DFF", "#FDD99BFF",
           "#FED89AFF", "#FED698FF", "#FED496FF", "#FED294FF", "#FED093FF",
           "#FECE91FF", "#FECC8FFF", "#FECB8EFF", "#FEC98CFF", "#FEC78BFF",
           "#FEC588FF", "#FEC287FF", "#FEC185FF", "#FEBF84FF", "#FEBD82FF",
           "#FEBB81FF", "#FEB97FFF", "#FEB77EFF", "#FEB67CFF", "#FEB47BFF",
           "#FEB27AFF", "#FEB078FF", "#FEAE77FF", "#FEAC76FF", "#FEAA74FF",
           "#FEA973FF", "#FEA772FF", "#FEA571FF", "#FEA36FFF", "#FEA16EFF",
           "#FE9F6DFF", "#FE9C6CFF", "#FD9B6BFF", "#FD9A6AFF", "#FD9769FF",
           "#FD9568FF", "#FD9367FF", "#FD9166FF", "#FC8F65FF", "#FC8D64FF",
           "#FC8B63FF", "#FC8A61FF", "#FC8861FF", "#FB8661FF", "#FB845FFF",
           "#FA825FFF", "#FA805EFF", "#FA7E5EFF", "#F97C5DFF", "#F97A5DFF",
           "#F9785DFF", "#F8775CFF", "#F8755CFF", "#F7725CFF", "#F7705CFF",
           "#F66E5CFF", "#F66C5CFF", "#F56B5CFF", "#F4695CFF", "#F4675CFF",
           "#F3655CFF", "#F2645CFF", "#F2625DFF", "#F1605DFF", "#F05F5EFF",
           "#EF5D5EFF", "#EE5B5EFF", "#ED5A5FFF", "#EC5860FF", "#EB5760FF",
           "#EA5661FF", "#E95462FF", "#E85362FF", "#E75263FF", "#E55064FF",
           "#E44F64FF", "#E34E65FF", "#E14D66FF", "#E04B67FF", "#DF4A68FF",
           "#DD4968FF", "#DC4869FF", "#DA476AFF", "#D9456CFF", "#D7456CFF",
           "#D5456CFF", "#D4436EFF", "#D2426FFF", "#D1416FFF", "#CF4070FF",
           "#CE4071FF", "#CC3F71FF", "#CB3E72FF", "#C83E73FF", "#C73D73FF",
           "#C53C74FF", "#C43C75FF", "#C23B75FF", "#C03A76FF", "#BF3A77FF",
           "#BD3977FF", "#BC3978FF", "#BA3878FF", "#B83779FF", "#B73779FF",
           "#B5367AFF", "#B3367AFF", "#B2357BFF", "#B0357BFF", "#AE347BFF",
           "#AD347CFF", "#AB337CFF", "#AA337DFF", "#A8327DFF", "#A6317DFF",
           "#A5317EFF", "#A2307EFF", "#A1307EFF", "#9F2F7FFF", "#9D2F7FFF",
           "#9C2E7FFF", "#9A2E7FFF", "#992D80FF", "#972D80FF", "#952C80FF",
           "#942C80FF", "#922B81FF", "#902A81FF", "#8F2A81FF", "#8D2981FF",
           "#8B2981FF", "#8A2881FF", "#882781FF", "#872781FF", "#842681FF",
           "#832681FF", "#812581FF", "#802582FF", "#7E2482FF", "#7C2382FF",
           "#7B2382FF", "#792282FF", "#782281FF", "#762181FF", "#752181FF",
           "#732081FF", "#721F81FF", "#701F81FF", "#6E1E81FF", "#6D1D81FF",
           "#6B1D81FF", "#6A1C81FF", "#681C81FF", "#671B80FF", "#651A80FF",
           "#641A80FF", "#611980FF", "#601880FF", "#5E187FFF", "#5D177FFF",
           "#5B167FFF", "#5A167EFF", "#58157EFF", "#57157EFF", "#55137DFF",
           "#53137DFF", "#52137CFF", "#50127BFF", "#4E117BFF", "#4D117BFF",
           "#4B1079FF", "#491078FF", "#481078FF", "#461077FF", "#440F76FF",
           "#430F75FF", "#400F74FF", "#3F0F72FF", "#3E0F71FF", "#3B0F70FF",
           "#390F6EFF", "#38106CFF", "#36106BFF", "#341069FF", "#331067FF",
           "#311165FF", "#2F1163FF", "#2D1161FF", "#2C115FFF", "#2A115CFF",
           "#29115AFF", "#271258FF", "#251255FF", "#241253FF", "#221150FF",
           "#21114EFF", "#20114AFF", "#1E1149FF", "#1D1146FF", "#1B1043FF",
           "#1A1041FF", "#19103EFF", "#170F3CFF", "#160F3AFF", "#150E37FF",
           "#140E35FF", "#130D33FF", "#120D30FF", "#110B2EFF", "#0F0B2CFF",
           "#0D0B2AFF", "#0C0927FF", "#0B0925FF", "#0A0823FF", "#090721FF",
           "#08071FFF", "#07061DFF", "#06051BFF", "#060519FF", "#050416FF",
           "#040414FF", "#030312FF", "#03030FFF", "#02020DFF", "#02020BFF",
           "#020109FF", "#010108FF", "#010106FF", "#010005FF", "#000004FF")

#' @title Interaction Plot for an "mp" Class Object.
#'
#' @description This function plots the interaction between received treatment and recommended treatment,
#' which provides an estimate of the treatment effect in the identified subgroup.
#'
#' @details In the interaction plot, each point is the (inverse-probability-weighted) group mean given a
#' received treatment and a recommended treatment. Although it usually
#' overestimates the treatment effect in the training set, the interaction plot provides a sanity check for
#' treatment recommendation rules. Given a specific index of the penalty parameter, the function
#' draws the corresponding interaction plots.
#'
#' @param x A fitted "mp" class object returned by \code{mpersonalized} function
#' @param penalty_index The index of penalty parameter configuration in \code{mp$penalty_parameter_sequence}.
#' When \code{mp$penalty = "none"}, \code{penalty_index} is ignored and automatically set to be 1.
#' @param ... not used
#'
#' @import ggplot2 gridExtra
#' @return A list of \code{ggplot} objects for the chosen penalty parameter configuration.
#' Specifically, \eqn{k}th element is the interaction plot for the \eqn{k}th study/outcome.
#'
#' @examples
#' set.seed(123)
#' sim_dat  = simulated_dataset(n = 200, problem = "meta-analysis")
#' Xlist = sim_dat$Xlist; Ylist = sim_dat$Ylist; Trtlist = sim_dat$Trtlist
#'
#' # fit different rules with lasso penalty for this meta-analysis problem
#' mp_mod_diff = mpersonalized(problem = "meta-analysis",
#'                             Xlist = Xlist, Ylist = Ylist, Trtlist = Trtlist,
#'                             penalty = "lasso", single_rule = FALSE)
#'
#' # interaction plot of the 5th penalty parameter
#' plots = plot(x = mp_mod_diff, penalty_index = 5)
#' set.seed(NULL)
#' @export
plot.mp <- function(x, penalty_index, ...)
{
  mp <- x

  if (mp$penalty == "none") {
    # Only a single (unpenalized) fit exists, so penalty_index is ignored.
    pred <- predict(mp)$opt_treatment[[1]]
  } else {
    if (missing(penalty_index))
      stop("For penalty not equal to 'none', penalty_index must be inputed!")

    opt_treatment <- predict(mp)$opt_treatment
    n_configs <- length(opt_treatment)
    # Give a clear message instead of a bare "subscript out of bounds".
    if (!is.numeric(penalty_index) || length(penalty_index) != 1L ||
        is.na(penalty_index) || penalty_index < 1 || penalty_index > n_configs)
      stop("penalty_index must be a single index between 1 and ", n_configs,
           call. = FALSE)

    pred <- opt_treatment[[penalty_index]]
  }

  .interaction_plots(Trtlist = mp$Trtlist, predlist = pred, Ylist = mp$Ylist,
                     Plist = mp$Plist, q = mp$number_studies_or_outcomes,
                     problem = mp$problem)
}


#' @title Interaction Plot for an "mp_cv" Class Object.
#'
#' @description This function plots the interaction between received treatment and recommended treatment,
#' given the optimal penalty parameter.
#'
#' @param x A fitted 'mp_cv' class object returned by \code{mpersonalized_cv} function
#' @param ... not used
#'
#' @import ggplot2 gridExtra
#' @return A list object representing the interaction plots for the optimal penalty parameter configuration.
#' Specifically, \eqn{k}th element is the interaction plot for the \eqn{k}th study/outcome.
#'
#' @examples
#' set.seed(123)
#' sim_dat  = simulated_dataset(n = 200, problem = "meta-analysis")
#' Xlist = sim_dat$Xlist; Ylist = sim_dat$Ylist; Trtlist = sim_dat$Trtlist
#'
#' # fit different rules with lasso penalty for this meta-analysis problem
#' mp_cvmod_diff = mpersonalized_cv(problem = "meta-analysis",
#'                                  Xlist = Xlist, Ylist = Ylist, Trtlist = Trtlist,
#'                                  penalty = "lasso", single_rule = FALSE)
#'
#' plots = plot(x = mp_cvmod_diff)
#' set.seed(NULL)
#' @export
plot.mp_cv <- function(x, ...) {
  mp_cv <- x

  # Recommendations under the cross-validated (optimal) penalty parameter.
  pred <- predict(mp_cv)$opt_treatment

  .interaction_plots(Trtlist = mp_cv$Trtlist, predlist = pred, Ylist = mp_cv$Ylist,
                     Plist = mp_cv$Plist, q = mp_cv$number_studies_or_outcomes,
                     problem = mp_cv$problem)
}


# Internal helper shared by plot.mp() and plot.mp_cv(): builds one interaction
# plot per study/outcome.
#
# Trtlist, predlist, Ylist, Plist: per-study lists of received treatment,
#   recommended treatment, outcome and P, used with P as the weight denominator
#   for Trt == 1 and 1 - P for Trt == 0 (presumably the propensity of
#   receiving treatment 1 -- confirm against mpersonalized()).
# q: number of studies/outcomes; problem: "meta-analysis" or
#   "multiple outcomes" (selects the plot title; any other value gives none).
# Returns an unnamed list of q ggplot objects.
.interaction_plots <- function(Trtlist, predlist, Ylist, Plist, q, problem)
{
  # Dummy bindings so R CMD check does not flag the aes() column names.
  recommend <- received <- NULL

  # Weighted mean of Y over observations flagged by `in_group`, with weights
  # 1 / denom. Returns NaN when the group is empty.
  weighted_group_mean <- function(Y, in_group, denom) {
    w <- as.numeric(in_group) / denom
    sum(Y * w) / sum(w)
  }

  # Group order matches the `recommend`/`received` coding below:
  #   1: received 1, recommend 1    2: received 0, recommend 1
  #   3: received 0, recommend 0    4: received 1, recommend 0
  # Each group is normalised by its OWN total weight (the previous code
  # divided group 4 by the weight of group 1).
  meanlist <- mapply(function(Trt, pred, Y, P) {
    c(weighted_group_mean(Y, Trt == 1 & pred == 1, P),
      weighted_group_mean(Y, Trt == 0 & pred == 1, 1 - P),
      weighted_group_mean(Y, Trt == 0 & pred == 0, 1 - P),
      weighted_group_mean(Y, Trt == 1 & pred == 0, P))
  }, Trt = Trtlist, pred = predlist, Y = Ylist, P = Plist, SIMPLIFY = FALSE)

  title_prefix <- switch(problem,
                         "meta-analysis" = "Study",
                         "multiple outcomes" = "Outcome",
                         NULL)

  recommend_fac <- as.factor(c(1, 1, 0, 0))
  received_fac <- as.factor(c(1, 0, 0, 1))

  # seq_len() yields an empty list when q == 0 (1:q would iterate over 1, 0).
  lapply(seq_len(q), function(i) {
    plot_dat <- data.frame(mean = meanlist[[i]], recommend = recommend_fac,
                           received = received_fac)
    p <- ggplot(data = plot_dat, aes(y = mean, x = recommend, group = received)) +
      geom_line(aes(color = received), size = 2, alpha = 0.4) +
      geom_point(size = 3, aes(color = received))
    if (!is.null(title_prefix))
      p <- p + ggtitle(label = paste(title_prefix, i))
    p
  })
}


#' @title Cross Validation Error Plot for an "mp_cv" Class Object.
#'
#' @description This function plots the cross validation error as a function of the tuning parameters.
#' For penalties with 2 tuning parameters, a heat map is plotted via the \code{image()} function.
#'
#' @param mp_cv A fitted 'mp_cv' class object returned by \code{mpersonalized_cv} function
#' @param col.regions color scale. See \code{\link[Matrix]{image-methods}}
#' @param key.lab label for colorkey
#' @param ... arguments to be passed to \code{\link[lattice]{levelplot}}
#'
#' @return Called for its side effect of drawing the plot. With a single tuning parameter,
#' \code{NULL} is returned invisibly; with two tuning parameters, the \code{trellis} object
#' created by \code{image()} is printed and returned invisibly.
#'
#' @examples
#' set.seed(123)
#' sim_dat  = simulated_dataset(n = 200, problem = "meta-analysis")
#' Xlist = sim_dat$Xlist; Ylist = sim_dat$Ylist; Trtlist = sim_dat$Trtlist
#'
#' # fit different rules with lasso penalty for this meta-analysis problem
#' mp_cvmod_diff = mpersonalized_cv(problem = "meta-analysis",
#'                                  Xlist = Xlist, Ylist = Ylist, Trtlist = Trtlist,
#'                                  penalty = "lasso", single_rule = FALSE)
#'
#' plotCVE(mp_cvmod_diff)
#' set.seed(NULL)
#' @export
#' @importFrom grDevices topo.colors
plotCVE <- function(mp_cv,
                    col.regions = vcols,
                    key.lab = "CV Err",
                    ...)
{
  # inherits() handles multi-element class vectors; `class(x) != "mp_cv"`
  # would give a length > 1 `if` condition, which is an error in R >= 4.2.
  if (!inherits(mp_cv, "mp_cv"))
    stop("object supplied must be an 'mp_cv' object as returned by 'mpersonalized_cv()'",
         call. = FALSE)

  cv_error <- mp_cv$cv_error

  if (is.null(dim(cv_error))) {
    # One tuning parameter: CV error against the penalty sequence.
    plot(y = cv_error, type = "b", x = mp_cv$penalty_parameter_sequence,
         xlab = expression(lambda), ylab = "Cross Validation Error")
    return(invisible(NULL))
  }

  # Two tuning parameters: heat map with the first parameter (lambda_1) on the
  # y axis (rows of cv_error) and the second parameter on the x axis (columns).
  xlab <- if (identical(mp_cv$penalty, "SGL+SL")) expression(tau[0]) else expression(lambda[2])
  ylab <- expression(lambda[1])

  rn <- round(unique(mp_cv$penalty_parameter_sequence[, 1]), 2)
  cn <- round(unique(mp_cv$penalty_parameter_sequence[, 2]), 2)

  heatmap <- image(as(cv_error, "Matrix"), col.regions = col.regions, colorkey = TRUE,
                   xlab = xlab, ylab = ylab,
                   scales = list(y = list(labels = rn, at = seq_along(rn)),
                                 x = list(labels = cn, at = seq_along(cn), rot = 45)),
                   ylab.right = key.lab,
                   sub = NULL,
                   ...)

  # Lattice (trellis) plots are only drawn when printed. Print explicitly so the
  # heat map also appears when the result is assigned or the function is
  # called from inside another function or loop.
  print(heatmap)
  invisible(heatmap)
}
# chenshengkuang/mpersonalized documentation built on May 28, 2019, 7:16 p.m.