# File: todo-files/plotggVIS.R

#' Plot filter values using ggvis.
#'
#' @description
#' Creates a Shiny application that displays filter values as a ggvis bar
#' chart. The sidebar lets the user pick the filter method, the sort order of
#' the features and how many features (those with the largest values) are shown.
#' The function returns the object created by `shiny::shinyApp`.
#'
#' @family plot
#' @family filter
#'
#' @param fvalues ([FilterValues])\cr
#'   Filter values created by [generateFilterValuesData].
#'   Objects created by the deprecated `getFilterValues` are rejected.
#' @param feat.type.cols (`logical(1)`)\cr
#'   Whether to color the bars by feature type (factor vs. numeric).
#'   `FALSE` means no colors.
#'   Default is `FALSE`.
#' @template ret_ggv
#' @export
#' @examples \dontrun{
#' fv = generateFilterValuesData(iris.task, method = "variance")
#' plotFilterValuesGGVIS(fv)
#' }
plotFilterValuesGGVIS = function(fvalues, feat.type.cols = FALSE) {
  requirePackages("_ggvis")
  assertClass(fvalues, classes = "FilterValues")
  assertFlag(feat.type.cols)
  # Objects from the deprecated getFilterValues() carry a `method` element.
  if (!is.null(fvalues$method)) {
    stop("fvalues must be generated by generateFilterValuesData, not getFilterValues, which is deprecated.")
  }

  # Every column besides `name` and `type` holds the values of one filter
  # method; after melting, `method` names the filter and `value` holds its value.
  data = setDF(melt(as.data.table(fvalues$data), id.vars = c("name", "type"),
    variable.name = "method"))
  n.feat = sum(fvalues$task.desc$n.feat)

  # Build the bar chart for the (already subset and ordered) data of one method.
  createPlot = function(data, feat.type.cols) {
    x.prop = ggvis::prop("x", as.name("name"))
    y.prop = ggvis::prop("y", as.name("value"))
    if (feat.type.cols) {
      plt = ggvis::ggvis(data, x.prop, y.prop, ggvis::prop("fill", as.name("type")))
    } else {
      plt = ggvis::ggvis(data, x.prop, y.prop)
    }
    plt = ggvis::layer_bars(plt)
    plt = ggvis::add_axis(plt, "y", title = "")
    ggvis::add_axis(plt, "x", title = "")
  }

  requirePackages("_shiny")
  header = shiny::headerPanel(sprintf("%s (%i features)", fvalues$task.desc$id, n.feat))
  method.input = shiny::selectInput("level_variable", "choose a filter method",
    unique(levels(data[["method"]])))
  sort.input = shiny::radioButtons("sort_type", "sort features", c("increasing", "decreasing", "none"))
  n.show.input = shiny::numericInput("n_show", "number of features to show",
    value = n.feat,
    min = 1,
    max = n.feat,
    step = 1)
  ui = shiny::shinyUI(
    shiny::pageWithSidebar(
      header,
      shiny::sidebarPanel(method.input, sort.input, n.show.input),
      shiny::mainPanel(shiny::uiOutput("ggvis_ui"), ggvis::ggvisOutput("ggvis"))
    )
  )
  server = shiny::shinyServer(function(input, output) {
    plt = shiny::reactive({
      method.data = data[which(data[["method"]] == input$level_variable), ]
      plot.data = filterValuesPlotData(method.data, input$sort_type, "value", "name", input$n_show)
      createPlot(plot.data, feat.type.cols)
    })
    ggvis::bind_shiny(plt, "ggvis", "ggvis_ui")
  })
  shiny::shinyApp(ui, server)
}

# Select and order the rows of one filter method's data for plotting.
#
# Keeps the `n.show` rows with the largest values in `value.column`; the limit
# applies for every `sort.type`. For "increasing" / "decreasing" the
# `factor.column` becomes a factor whose levels follow that order, which fixes
# the bar order on the x axis. For "none" the kept rows stay in their original
# order. A missing `n.show` (NULL or NA, e.g. a cleared numeric input) keeps
# all rows.
filterValuesPlotData = function(data, sort.type, value.column, factor.column, n.show) {
  if (is.null(n.show) || is.na(n.show)) {
    n.show = nrow(data)
  }
  keep = head(order(data[[value.column]], decreasing = TRUE), n = n.show)
  if (sort.type == "none") {
    keep = sort(keep)
  }
  data = data[keep, , drop = FALSE]
  if (sort.type != "none") {
    data[[factor.column]] = factor(data[[factor.column]],
      levels = data[[factor.column]][order(data[[value.column]],
        decreasing = sort.type == "decreasing")])
  }
  data
}


#' @title Plot learning curve data using ggvis.
#'
#' @family plot
#' @family learning_curve
#'
#' @description
#' Visualizes data size (percentage used for model) vs. performance measure(s).
#' If the `interaction` variable has more than one unique value, a Shiny
#' application (created by `shiny::shinyApp`) is returned; otherwise a ggvis plot.
#'
#' @param obj ([LearningCurveData])\cr
#'   Result of [generateLearningCurveData].
#' @param interaction (`character(1)`)\cr
#'   Selects \dQuote{measure} or \dQuote{learner} to be used in a Shiny application
#'   making the `interaction` variable selectable via a dropdown menu.
#'   This variable must have more than one unique value, otherwise it will be ignored.
#'   The variable not chosen is mapped to color if it has more than one unique value.
#'   Note that if there are multiple learners and multiple measures interactivity is
#'   necessary as ggvis does not currently support facetting or subplots.
#'   The default is \dQuote{measure}.
#' @param pretty.names (`logical(1)`)\cr
#'   Whether to use the [Measure] name instead of the id in the plot.
#'   Default is `TRUE`.
#' @template ret_ggv
#' @export
plotLearningCurveGGVIS = function(obj, interaction = "measure", pretty.names = TRUE) {
  requirePackages("_ggvis")
  assertClass(obj, "LearningCurveData")
  mappings = c("measure", "learner")
  assertChoice(interaction, mappings)
  assertFlag(pretty.names)
  # the variable not used for interaction is mapped to color
  color = mappings[mappings != interaction]

  if (pretty.names) {
    mnames = replaceDupeMeasureNames(obj$measures, "name")
    colnames(obj$data) = mapValues(colnames(obj$data),
      names(obj$measures),
      mnames)
  }

  data = setDF(melt(as.data.table(obj$data), id.vars = c("learner", "percentage"),
    variable.name = "measure", value.name = "performance"))
  nmeas = length(unique(data[["measure"]]))
  nlearn = length(unique(data[["learner"]]))

  # a variable with a single unique value is used neither for color nor for interaction
  if ((color == "learner" && nlearn == 1L) || (color == "measure" && nmeas == 1L)) {
    color = NULL
  }
  if ((interaction == "learner" && nlearn == 1L) || (interaction == "measure" && nmeas == 1L)) {
    interaction = NULL
  }

  # Build the point + line plot; `data` may be a data.frame or a shiny reactive.
  createPlot = function(data, color) {
    if (!is.null(color)) {
      plt = ggvis::ggvis(data, ggvis::prop("x", as.name("percentage")),
        ggvis::prop("y", as.name("performance")),
        ggvis::prop("stroke", as.name(color)))
      plt = ggvis::layer_points(plt, ggvis::prop("fill", as.name(color)))
    } else {
      plt = ggvis::ggvis(data, ggvis::prop("x", as.name("percentage")),
        ggvis::prop("y", as.name("performance")))
      plt = ggvis::layer_points(plt)
    }
    ggvis::layer_lines(plt)
  }

  if (!is.null(interaction)) {
    requirePackages("_shiny")
    # factors keep their level order; for other column types levels() would be
    # NULL and leave the dropdown empty, so fall back to the sorted unique values
    choice.values = data[[interaction]]
    choices = if (is.factor(choice.values)) levels(choice.values) else sort(unique(choice.values))
    ui = shiny::shinyUI(
      shiny::pageWithSidebar(
        shiny::headerPanel("learning curve"),
        shiny::sidebarPanel(
          shiny::selectInput("interaction_select",
            stri_paste("choose a", interaction, sep = " "),
            choices)
        ),
        shiny::mainPanel(
          shiny::uiOutput("ggvis_ui"),
          ggvis::ggvisOutput("ggvis")
        )
    ))
    server = shiny::shinyServer(function(input, output) {
      data.sub = shiny::reactive(data[which(data[[interaction]] == input$interaction_select), ])
      plt = createPlot(data.sub, color)
      ggvis::bind_shiny(plt, "ggvis", "ggvis_ui")
    })
    shiny::shinyApp(ui, server)
  } else {
    createPlot(data, color)
  }
}

#' @title Plot a partial dependence using ggvis.
#' @description
#' Plot partial dependence from [generatePartialDependenceData] using ggvis.
#' This function is deprecated in favour of [plotPartialDependence].
#' If a feature is mapped to an interactive dropdown (see `interact`), a Shiny
#' application (created by `shiny::shinyApp`) is returned; otherwise a ggvis plot.
#'
#' @family partial_dependence
#' @family plot
#'
#' @param obj ([PartialDependenceData])\cr
#'   Generated by [generatePartialDependenceData].
#' @param interact (`character(1)`)\cr
#'   The name of a feature to be mapped to an interactive sidebar using Shiny.
#'   This feature must have been an element of the `features` argument to
#'   [generatePartialDependenceData] and is only applicable when said argument had length
#'   2 and `interaction = TRUE`. If it is `NULL` in that case, the feature with the
#'   most unique values is used.
#'   If [generatePartialDependenceData] was called with the `interaction` argument `FALSE`
#'   (the default) and argument `features` of length greater than one, then `interact`
#'   must be `NULL` (otherwise an error is raised) and the feature displayed is
#'   controlled by an interactive side panel.
#'   Default is `NULL`.
#' @param p (`numeric(1)`)\cr
#'   Fraction (between 0 and 1) of the individual curves to plot.
#'   Only usable if `obj` was generated with `individual = TRUE`: then
#'   `floor(p * n)` of the `n` individual curves are sampled without replacement
#'   to make the display more readable.
#'   Default is `1` (plot all curves).
#' @template ret_ggv
#' @export
plotPartialDependenceGGVIS = function(obj, interact = NULL, p = 1) {
  requirePackages("_ggvis")
  assertClass(obj, "PartialDependenceData")
  .Deprecated("plotPartialDependence")

  if (!is.null(interact)) {
    assertChoice(interact, obj$features)
  }
  if (obj$interaction && length(obj$features) > 2L) {
    stop("It is only possible to plot 2 features with this function.")
  }
  if (!obj$interaction && !is.null(interact)) {
    stop("obj argument created by generatePartialDependenceData was called with interaction = FALSE!")
  }

  # validate before comparing, so that e.g. p = NA fails with a clear message
  assertNumber(p, lower = 0, upper = 1, finite = TRUE)
  if (p != 1) {
    if (!obj$individual) {
      stop("obj argument created by generatePartialDependenceData must be called with individual = TRUE to use this argument!")
    }
    # sample whole individual curves (identified by idx), not single rows;
    # index via sample.int so that a single idx value is not expanded to 1:idx
    rows = unique(obj$data$idx)
    id = rows[sample.int(length(rows), size = floor(p * length(rows)))]
    obj$data = obj$data[obj$data$idx %in% id, ]
  }

  if (obj$interaction && length(obj$features) == 2L) {
    if (is.null(interact)) {
      # default: the feature with the most unique values goes to the dropdown
      n.unique = vapply(obj$features, function(f) length(unique(obj$data[[f]])), integer(1L))
      interact = obj$features[which.max(n.unique)]
    }
    x = obj$features[obj$features != interact]
    interact.values = obj$data[[interact]]
    if (is.factor(interact.values)) {
      choices = levels(interact.values)
    } else {
      choices = sort(unique(interact.values))
    }
  } else if (!obj$interaction && length(obj$features) > 1L) {
    # several features without interaction: stack them into Feature/Value
    # columns and let the dropdown select the feature
    id = colnames(obj$data)[!colnames(obj$data) %in% obj$features]
    obj$data = setDF(melt(as.data.table(obj$data), id.vars = id, variable.name = "Feature",
      value.name = "Value", na.rm = TRUE, variable.factor = TRUE))
    interact = "Feature"
    choices = obj$features
    x = NULL # not used: createPlot takes the x values from the "Value" column
  } else {
    interact = NULL
    x = obj$features
  }

  bounds = all(c("lower", "upper") %in% colnames(obj$data)) &&
    obj$task.desc$type %in% c("surv", "regr")

  if (obj$task.desc$type %in% c("regr", "classif")) {
    target = obj$task.desc$target
  } else {
    target = "Risk"
  }

  # Build the ggvis plot for (a subset of) the partial dependence data.
  # `interact` is read from the enclosing function: when it is set while
  # `interaction` is FALSE the features were stacked, so x values are in "Value".
  createPlot = function(td, target, interaction, individual, data, x, bounds) {
    classif = td$type == "classif" && all(target %in% td$class.levels)
    x.col = if (!interaction && !is.null(interact)) "Value" else x
    x.prop = ggvis::prop("x", as.name(x.col))
    if (classif) {
      plt = ggvis::ggvis(data, x.prop,
        ggvis::prop("y", as.name("Probability")),
        ggvis::prop("stroke", as.name("Class")))
    } else { ## regression/survival
      plt = ggvis::ggvis(data, x.prop, ggvis::prop("y", as.name(target)))
    }

    if (bounds) {
      plt = ggvis::layer_ribbons(plt,
        ggvis::prop("y", as.name("lower")),
        ggvis::prop("y2", as.name("upper")),
        ggvis::prop("opacity", .5))
    }
    if (individual) {
      plt = ggvis::group_by_.ggvis(plt, as.name("idx"))
      plt = ggvis::layer_paths(plt, ggvis::prop("opacity", .25))
    } else {
      if (classif) {
        plt = ggvis::layer_points(plt, ggvis::prop("fill", as.name("Class")))
      } else {
        plt = ggvis::layer_points(plt)
      }
      plt = ggvis::layer_lines(plt)
    }

    plt
  }

  if (obj$derivative) {
    header = stri_paste(target, "(derivative)", sep = " ")
  } else {
    header = target
  }

  if (!is.null(interact)) {
    requirePackages("_shiny")
    panel = shiny::selectInput("interaction_select", interact, choices)
    ui = shiny::shinyUI(
      shiny::pageWithSidebar(
        shiny::headerPanel(header),
        shiny::sidebarPanel(panel),
        shiny::mainPanel(
          shiny::uiOutput("ggvis_ui"),
          ggvis::ggvisOutput("ggvis")
        )
    ))
    server = shiny::shinyServer(function(input, output) {
      plt = shiny::reactive(createPlot(obj$task.desc, obj$target, obj$interaction, obj$individual,
        obj$data[obj$data[[interact]] == input$interaction_select, ],
        x, bounds))
      ggvis::bind_shiny(plt, "ggvis", "ggvis_ui")
    })
    shiny::shinyApp(ui, server)
  } else {
    createPlot(obj$task.desc, obj$target, obj$interaction, obj$individual, obj$data, x, bounds)
  }
}

#' @title Plot threshold vs. performance(s) for two-class classification using ggvis.
#'
#' @description
#' Plots threshold vs. performance(s) data that has been generated with [generateThreshVsPerfData].
#' If the `interaction` variable has more than one unique value, a Shiny
#' application (created by `shiny::shinyApp`) is returned; otherwise a ggvis plot.
#'
#' @family plot
#' @family thresh_vs_perf
#'
#' @param obj ([ThreshVsPerfData])\cr
#'   Result of [generateThreshVsPerfData].
#' @param interaction (`character(1)`)\cr
#'   Selects \dQuote{measure} or \dQuote{learner} to be used in a Shiny application
#'   making the `interaction` variable selectable via a dropdown menu.
#'   This variable must have more than one unique value, otherwise it will be ignored.
#'   The variable not chosen is mapped to color if it has more than one unique value.
#'   Note that if there are multiple learners and multiple measures interactivity is
#'   necessary as ggvis does not currently support facetting or subplots.
#'   The default is \dQuote{measure}.
#' @param mark.th (`numeric(1)`)\cr
#'   Mark given threshold with vertical line?
#'   The line is only drawn for the static (non-Shiny) plot.
#'   Default is `NA` which means not to do it.
#' @param pretty.names (`logical(1)`)\cr
#'   Whether to use the [Measure] name instead of the id in the plot.
#'   Default is `TRUE`.
#' @template ret_ggv
#' @export
#' @examples \dontrun{
#' lrn = makeLearner("classif.rpart", predict.type = "prob")
#' mod = train(lrn, sonar.task)
#' pred = predict(mod, sonar.task)
#' pvs = generateThreshVsPerfData(pred, list(tpr, fpr))
#' plotThreshVsPerfGGVIS(pvs)
#' }
plotThreshVsPerfGGVIS = function(obj, interaction = "measure", mark.th = NA_real_, pretty.names = TRUE) {
  requirePackages("_ggvis")
  assertClass(obj, classes = "ThreshVsPerfData")
  mappings = c("measure", "learner")
  assertChoice(interaction, mappings)
  assertNumber(mark.th, na.ok = TRUE)
  assertFlag(pretty.names)
  # the variable not used for interaction is mapped to color
  color = mappings[mappings != interaction]

  if (pretty.names) {
    mnames = replaceDupeMeasureNames(obj$measures, "name")
    colnames(obj$data) = mapValues(colnames(obj$data),
      names(obj$measures),
      mnames)
  } else {
    mnames = names(obj$measures)
  }

  id.vars = "threshold"
  resamp = "iter" %in% colnames(obj$data)
  if (resamp) id.vars = c(id.vars, "iter")
  if ("learner" %in% colnames(obj$data)) id.vars = c(id.vars, "learner")

  data = setDF(melt(as.data.table(obj$data), measure.vars = mnames, variable.name = "measure",
    value.name = "performance", id.vars = id.vars))
  nmeas = length(unique(data[["measure"]]))
  if ("learner" %in% colnames(data)) {
    nlearn = length(unique(data[["learner"]]))
  } else {
    nlearn = 1L
  }

  # a variable with a single unique value is used neither for color nor for interaction
  if ((color == "learner" && nlearn == 1L) || (color == "measure" && nmeas == 1L)) {
    color = NULL
  }
  if ((interaction == "learner" && nlearn == 1L) || (interaction == "measure" && nmeas == 1L)) {
    interaction = NULL
  }

  # with unaggregated resampling results draw one path per iteration (and color)
  if (resamp && !obj$aggregate && is.null(color)) {
    group = "iter"
  } else if (resamp && !obj$aggregate && !is.null(color)) {
    group = c("iter", color)
  } else {
    group = NULL
  }

  # Build the plot; `data` may be a data.frame or a shiny reactive.
  createPlot = function(data, color = NULL, group = NULL, measures) {
    x.prop = ggvis::prop("x", as.name("threshold"))
    y.prop = ggvis::prop("y", as.name("performance"))
    if (!is.null(color)) {
      plt = ggvis::ggvis(data, x.prop, y.prop, ggvis::prop("stroke", as.name(color)))
    } else {
      plt = ggvis::ggvis(data, x.prop, y.prop)
    }

    if (!is.null(group)) {
      plt = ggvis::group_by(plt, .dots = group)
    }

    plt = ggvis::layer_paths(plt)

    # the vertical line needs the concrete performance range, which is not
    # available for reactive data, so it is only drawn for the static plot
    if (!is.na(mark.th) && is.null(interaction)) {
      vline.data = data.frame(x2 = rep(mark.th, 2L),
        y2 = range(data[["performance"]], na.rm = TRUE),
        measure = data[["measure"]][1L])
      plt = ggvis::layer_lines(plt, ggvis::prop("x", as.name("x2")),
        ggvis::prop("y", as.name("y2")),
        ggvis::prop("stroke", "grey", scale = FALSE), data = vline.data)
    }
    plt = ggvis::add_axis(plt, "x", title = "threshold")

    if (length(measures) > 1L) {
      y.title = "performance"
    } else if (pretty.names) {
      y.title = measures[[1L]]$name
    } else {
      y.title = measures[[1L]]$id
    }
    ggvis::add_axis(plt, "y", title = y.title)
  }

  if (!is.null(interaction)) {
    requirePackages("_shiny")
    # factors keep their level order; for other column types levels() would be
    # NULL and leave the dropdown empty, so fall back to the sorted unique values
    choice.values = data[[interaction]]
    choices = if (is.factor(choice.values)) levels(choice.values) else sort(unique(choice.values))
    ui = shiny::shinyUI(
      shiny::pageWithSidebar(
        shiny::headerPanel("Threshold vs. Performance"),
        shiny::sidebarPanel(
          shiny::selectInput("interaction_select",
            stri_paste("choose a", interaction, sep = " "),
            choices)
        ),
        shiny::mainPanel(
          shiny::uiOutput("ggvis_ui"),
          ggvis::ggvisOutput("ggvis")
        )
    ))
    server = shiny::shinyServer(function(input, output) {
      data.sub = shiny::reactive(data[which(data[[interaction]] == input$interaction_select), ])
      plt = createPlot(data.sub, color, group, obj$measures)
      ggvis::bind_shiny(plt, "ggvis", "ggvis_ui")
    })
    shiny::shinyApp(ui, server)
  } else {
    createPlot(data, color, group, obj$measures)
  }
}


#' @title Plots multicriteria results after tuning using ggvis.
#'
#' @description
#' Visualizes the pareto front and possibly the dominated points.
#' The first two columns of `res$y` are used for the x and y axis.
#'
#' @param res ([TuneMultiCritResult])\cr
#'   Result of [tuneParamsMultiCrit].
#' @param path (`logical(1)`)\cr
#'   Visualize all evaluated points (or only the nondominated pareto front)?
#'   Points are colored according to their location.
#'   Default is `TRUE`.
#' @param point.info (`character(1)`)\cr
#'   Show for each point which hyper parameters led to this point?
#'   For `"click"`, information is displayed on mouse click.
#'   For `"hover"`, information is displayed on mouse hover.
#'   If set to `"none"`, no information is displayed.
#'   Default is `"hover"`.
#' @param point.trafo (`logical(1)`)\cr
#'   Should the information show the transformed hyper parameters?
#'   Default is `TRUE`.
#'
#' @template ret_ggv
#' @family tune_multicrit
#' @export
#' @examples
#' # see tuneParamsMultiCrit
plotTuneMultiCritResultGGVIS = function(res, path = TRUE, point.info = "hover", point.trafo = TRUE) {
  requirePackages("_ggvis")
  assertClass(res, "TuneMultiCritResult")
  assertFlag(path)
  assertChoice(point.info, choices = c("click", "hover", "none"))
  assertFlag(point.trafo)

  plt.data = as.data.frame(res$opt.path)
  # res$ind is treated as the row indices of the pareto-optimal points; compare
  # against row positions rather than (character) row names, which need not be 1..n
  on.front = seq_len(nrow(plt.data)) %in% res$ind
  plt.data$location = factor(on.front, levels = c(TRUE, FALSE),
    labels = c("frontier", "interior"))
  # stable key for the tooltips, assigned before any rows are dropped
  plt.data$id = seq_len(nrow(plt.data))

  if (point.trafo) {
    for (param in res$opt.path$par.set$pars) {
      plt.data[, param$id] = trafoValue(param, plt.data[, param$id])
    }
  }

  if (!path) {
    plt.data = plt.data[plt.data$location == "frontier", , drop = FALSE]
  }

  # Tooltip text: the first n columns (taken as the hyperparameters, n being the
  # length of the first pareto setting) of the point under the cursor.
  info = function(x) {
    if (is.null(x)) {
      return(NULL)
    }
    n = length(res$x[[1L]])
    row = plt.data[plt.data$id == x$id, seq_len(n), drop = FALSE]
    paste0(names(row), ": ", format(row, zero.print = TRUE), collapse = "<br />")
  }

  plt = ggvis::ggvis(plt.data, ggvis::prop("x", as.name(colnames(res$y)[1L])),
    ggvis::prop("y", as.name(colnames(res$y)[2L])),
    ggvis::prop("key", ~id, scale = FALSE))
  plt = ggvis::layer_points(plt, ggvis::prop("fill", as.name("location")))

  if (point.info != "none") {
    plt = ggvis::add_tooltip(plt, info, point.info)
  }

  plt
}
# berndbischl/mlr documentation built on Jan. 6, 2023, 12:45 p.m.