R/plot2Way.R

Defines functions plot2Way

Documented in plot2Way

#' Plot two-way interactions from gbm model
#'
#' Predicts the gbm model over a grid spanning two variables (using
#' [gbm::plot.gbm()] with `return.grid = TRUE`) and draws the predictions as a
#' heat map. When both variables are numeric the surface is drawn with
#' [ggplot2::geom_raster()]; otherwise it is drawn with
#' [ggplot2::geom_tile()].
#'
#' @param dw_model Model object from running [buildMod()].
#' @param variable The variables to plot. Must be of length two e.g. `variable
#'   = c("ws", "wd")`.
#' @param res Resolution in x-y i.e. number of points in each dimension for
#'   continuous variables.
#' @param exclude Should surfaces exclude predictions too far from original data?
#'   Only applied when both variables are numeric. The default is `TRUE`.
#' @param cols Colours to be used for plotting. Options include
#'   \dQuote{default}, \dQuote{increment}, \dQuote{heat}, \dQuote{jet} and user
#'   defined. For user defined the user can supply a vector of colour names
#'   recognised by R (type `colours()` to see the full list). An example would
#'   be `cols = c("yellow", "green", "blue")`.
#' @param dist When plotting surfaces, `dist` controls how far from the original
#'   data the predictions should be made. See `exclude.too.far` from the `mgcv`
#'   package. Data are first transformed to a unit square. Values should be
#'   between 0 and 1.
#' @param plot Should a plot be produced? `FALSE` can be useful when analysing
#'   data to extract plot components and plotting them in other ways.
#' @param ... Currently unused.
#' @export
#' @return Invisibly, a list with elements `plot` (the ggplot2 object) and
#'   `data` (a tibble of the prediction grid, with one column per plotted
#'   variable and the model predictions in column `y`).
#' @family deweather model plotting functions
#' @author David Carslaw
plot2Way <- function(dw_model,
                     variable = c("ws", "air_temp"),
                     res = 100,
                     exclude = TRUE,
                     cols = "default",
                     dist = 0.05,
                     plot = TRUE,
                     ...) {
  check_dwmod(dw_model)

  if (length(variable) != 2L) {
    stop("`variable` must be of length two.", call. = FALSE)
  }

  ## extract from deweather object
  data <- dw_model$data
  mod <- dw_model$model

  ## prediction grid: one column per variable plus the prediction `y`. Kept
  ## in its own variable so the `res` argument is not overwritten.
  grid <- gbm::plot.gbm(
    mod,
    i.var = variable,
    continuous.resolution = res,
    return.grid = TRUE
  )

  ## exclude predictions too far from data (from mgcv). The grid's own
  ## columns are passed directly so each flag lines up with its grid row,
  ## whatever the grid ordering or the number of levels per variable.
  if (exclude && all(vapply(grid[variable], is.numeric, logical(1)))) {
    obs <- stats::na.omit(data[, variable]) ## observed pairs of variables
    too_far <- mgcv::exclude.too.far(
      g1 = grid[[variable[1]]],
      g2 = grid[[variable[2]]],
      d1 = obs[[variable[1]]],
      d2 = obs[[variable[2]]],
      dist = dist
    )
    grid$y[too_far] <- NA
  }

  ## NOTE(review): decimalDate() is assumed to add a `date` column derived
  ## from the decimal `trend` values -- TODO confirm against its definition.
  if ("trend" %in% names(grid)) {
    grid <- decimalDate(grid, "trend")
    grid$trend <- grid$date
  }

  var1 <- variable[1]
  var2 <- variable[2]

  ## shared by both plot types; NA (excluded) cells are left transparent
  fill_scale <- ggplot2::scale_fill_gradientn(
    colours = openair::openColours(cols, 100),
    na.value = "transparent"
  )
  fill_label <- ggplot2::labs(fill = openair::quickText(mod$response.name))

  if (all(vapply(grid, is.numeric, logical(1)))) {
    plt <- ggplot2::ggplot(
      grid,
      ggplot2::aes(.data[[var1]], .data[[var2]], fill = .data[["y"]])
    ) +
      ggplot2::geom_raster() +
      fill_scale +
      fill_label
  } else {
    ## need to rename variables that use openair dates
    if ("hour" %in% variable) {
      variable[variable == "hour"] <- "Hour"
      grid <- dplyr::rename(grid, Hour = "hour")
      var1 <- "Hour"
      var2 <- variable[variable != "Hour"]
    }

    if ("weekday" %in% variable) {
      variable[variable == "weekday"] <- "Weekday"
      grid <- dplyr::rename(grid, Weekday = "weekday")

      ## NOTE(review): relabelling assumes the model's weekday factor levels
      ## correspond to abbreviated day names in alphabetical order -- confirm.
      weekday_names <- format(ISOdate(2000, 1, 2:8), "%a")
      levels(grid$Weekday) <- sort(weekday_names)
      grid$Weekday <- ordered(grid$Weekday, levels = weekday_names)
      var1 <- "Weekday"
      var2 <- variable[variable != "Weekday"]
    }

    plt <- ggplot2::ggplot(
      grid,
      ggplot2::aes(.data[[var1]], .data[[var2]], fill = .data[["y"]])
    ) +
      ggplot2::geom_tile() +
      fill_scale +
      fill_label
  }

  if (plot) {
    print(plt)
  }

  invisible(list(plot = plt, data = dplyr::as_tibble(grid)))
}
davidcarslaw/deweather documentation built on March 27, 2024, 8:18 a.m.