R/plot.R

Defines functions overlay_ringsearch overlay_piecharts overlay_sources overlay_risk_map overlay_surface overlay_geoprofile overlay_spatial_prior overlay_points overlay_trial_sites overlay_sentinels plot_lorenz plot_DIC_gelman plot_map plot_loglike_diagnostic plot_density plot_acf plot_trace plot_alpha plot_expected_popsize plot_sigma plot_structure plot_coupling plot_loglike theme_empty more_colours col_tim col_hotcold

Documented in more_colours overlay_geoprofile overlay_piecharts overlay_points overlay_ringsearch overlay_risk_map overlay_sentinels overlay_sources overlay_spatial_prior overlay_surface overlay_trial_sites plot_acf plot_alpha plot_coupling plot_density plot_DIC_gelman plot_expected_popsize plot_loglike plot_loglike_diagnostic plot_lorenz plot_map plot_sigma plot_structure plot_trace

#------------------------------------------------
# red-to-blue colours: linearly interpolates six fixed anchor colours (running
# from red to blue) out to n colours.
#' @importFrom grDevices colorRampPalette
#' @noRd
col_hotcold <- function(n = 6) {
  anchors <- c("#D73027", "#FC8D59", "#FEE090",
               "#E0F3F8", "#91BFDB", "#4575B4")
  colorRampPalette(anchors)(n)
}

#------------------------------------------------
# blue-to-red colours. Full credit to tim.colors from the fields package, from
# which these colours derive. Copied rather than including the fields package to
# avoid dependency on another package for the sake of a single colour scheme.
# The 64 anchor colours below (eight per row) run dark blue -> blue -> cyan ->
# green -> yellow -> orange -> red -> dark red, and are linearly interpolated
# to n output colours.
#' @importFrom grDevices colorRampPalette
#' @noRd
col_tim <- function(n = 10) {
  tim_anchors <- c(
    "#00008F", "#00009F", "#0000AF", "#0000BF", "#0000CF", "#0000DF", "#0000EF", "#0000FF",
    "#0010FF", "#0020FF", "#0030FF", "#0040FF", "#0050FF", "#0060FF", "#0070FF", "#0080FF",
    "#008FFF", "#009FFF", "#00AFFF", "#00BFFF", "#00CFFF", "#00DFFF", "#00EFFF", "#00FFFF",
    "#10FFEF", "#20FFDF", "#30FFCF", "#40FFBF", "#50FFAF", "#60FF9F", "#70FF8F", "#80FF80",
    "#8FFF70", "#9FFF60", "#AFFF50", "#BFFF40", "#CFFF30", "#DFFF20", "#EFFF10", "#FFFF00",
    "#FFEF00", "#FFDF00", "#FFCF00", "#FFBF00", "#FFAF00", "#FF9F00", "#FF8F00", "#FF8000",
    "#FF7000", "#FF6000", "#FF5000", "#FF4000", "#FF3000", "#FF2000", "#FF1000", "#FF0000",
    "#EF0000", "#DF0000", "#CF0000", "#BF0000", "#AF0000", "#9F0000", "#8F0000", "#800000"
  )
  interpolate <- colorRampPalette(tim_anchors)
  interpolate(n)
}

#------------------------------------------------
#' @title Expand series of colours by interpolation
#'
#' @description Expand a series of colours by interpolation to produce any
#'   number of colours from a given series. The pattern of interpolation is
#'   designed so that (n+1)th value contains the nth value plus one more colour,
#'   rather than being a completely different series. For example, running
#'   \code{more_colours(5)} and \code{more_colours(4)}, the first 4 colours will
#'   be shared between the two series.
#'
#' @details If \code{n} does not exceed \code{length(raw_cols)}, the first
#'   \code{n} raw colours are returned unchanged. Otherwise the raw colours are
#'   treated as equally spaced points on a colour ramp. The intervals between
#'   them are halved repeatedly. The raw colours come first in the output,
#'   followed by the midpoints introduced by each successive halving, until
#'   \code{n} colours have been collected.
#'
#' @param n how many colours to return.
#' @param raw_cols vector of colours to interpolate.
#'
#' @return character vector of \code{n} colours.
#'
#' @import RColorBrewer
#' @importFrom grDevices colorRampPalette
#' @export

more_colours <- function(n = 5, 
                         raw_cols = brewer.pal(12, "Paired")) {

  # check inputs
  assert_single_pos_int(n, zero_allowed = FALSE)
  assert_string(raw_cols)
  assert_vector(raw_cols)
  assert_greq(length(raw_cols), 2)

  # simple case if n within raw cols
  n_raw <- length(raw_cols)
  if (n <= n_raw) {
    return(raw_cols[seq_len(n)])
  }

  # Count the halving rounds needed for the finest grid to hold at least n
  # points. That grid has (n_raw - 1) * 2^(n_steps - 1) intervals. Counting
  # with integer arithmetic avoids rounding in a log()-based formula.
  n_int <- n_raw - 1
  n_steps <- 1
  while (n_int * 2^(n_steps - 1) + 1 < n) {
    n_steps <- n_steps + 1
  }
  n_final <- n_int * 2^(n_steps - 1)

  # Positions on the finest grid (0-based), in the order they are introduced:
  # all raw colours first, then the new midpoints from each halving round.
  # Exact integer positions replace a floating-point match() between seq()
  # grids of different lengths.
  pos <- unlist(lapply(seq_len(n_steps), function(j) {
    stride <- 2^(n_steps - j)
    p <- seq(0, n_final, by = stride)
    if (j > 1) {
      # keep only points not already present on the previous, coarser grid
      p <- p[p %% (2 * stride) != 0]
    }
    p
  }))

  # evaluate the ramp on the finest grid and pick the first n positions
  all_cols <- colorRampPalette(raw_cols)(n_final + 1)
  all_cols[pos[seq_len(n)] + 1]
}

#------------------------------------------------
# ggplot theme with minimal objects: blanks axis lines, axis text, axis ticks,
# panel background, panel border, major/minor grid lines and plot background.
#' @import ggplot2
#' @noRd
theme_empty <- function() {
  blank <- element_blank()
  theme(
    axis.line = blank,
    axis.text.x = blank,
    axis.text.y = blank,
    axis.ticks = blank,
    panel.background = blank,
    panel.border = blank,
    panel.grid.major = blank,
    panel.grid.minor = blank,
    plot.background = blank
  )
}

#------------------------------------------------
#' @title Plot loglikelihood 95\% credible intervals
#'
#' @description Plot loglikelihood 95\% credible intervals of current active
#'   set. One interval is drawn per rung, and each median point is coloured by
#'   thermodynamic power.
#'
#' @param project an RgeoProfile project, as produced by the function
#'   \code{rgeoprofile_project()}.
#' @param K which value of K to produce the plot for.
#' @param x_axis_type how to format the x-axis. 1 = integer rungs, 2 = values of
#'   beta raised to the GTI power.
#' @param y_axis_type how to format the y-axis. 1 = raw values, 2 = truncated at
#'   an auto-chosen lower limit (the median of the lower interval bounds), 3 =
#'   deviance (-2 * log-likelihood) on a log10 scale. If any deviance is
#'   negative, all values are shifted up by a power of ten so they are strictly
#'   positive. The shift is shown in the axis label.
#' @param phase which phase to plot. Must be either "burnin" or "sampling".
#' @param legend switches the legend for thermodynamic power on or off.
#'
#' @return a ggplot object.
#' 
#' @import ggplot2
#' @importFrom grDevices grey
#' @export

plot_loglike <- function(project, 
                         K = 1, 
                         x_axis_type = 1, 
                         y_axis_type = 1, 
                         phase = "sampling", 
                         legend = TRUE) {
  
  # check inputs
  assert_custom_class(project, "rgeoprofile_project")
  assert_single_pos_int(K, zero_allowed = FALSE)
  assert_in(x_axis_type, 1:2)
  assert_in(y_axis_type, 1:3)
  assert_in(phase, c("burnin", "sampling"))
  assert_single_logical(legend)
  
  # get thermodynamic powers, one per rung
  beta_vec <- get_output(project, type = "summary", name = "beta_vec", K = K)
  rungs <- length(beta_vec)
  
  # define x-axis type
  if (x_axis_type == 1) {
    x_vec <- seq_len(rungs)
    x_lab <- "rung"
  } else {
    x_vec <- beta_vec
    x_lab <- "thermodynamic power"
  }
  
  # get plotting data ("loglike_burnin" or "loglike_sampling")
  loglike <- get_output(project, paste0("loglike_", phase), type = "raw", K = K)
  y_lab <- "log-likelihood"
  
  # move to plotting deviance if specified
  if (y_axis_type == 3) {
    loglike <- -2 * loglike
    y_lab <- "deviance"
    
    # A log10 axis needs strictly positive values. Shift by the smallest power
    # of ten that is strictly greater than |min|. The former
    # ceiling(log(|min|)/log(10)) could equal |min| (e.g. 1000) and leave a zero.
    if (min(loglike) < 0) {
      dev_scale_power <- floor(log10(abs(min(loglike)))) + 1
      loglike <- loglike + 10^dev_scale_power
      
      dev_scale_label <- if (dev_scale_power == 0) {
        "1"
      } else if (dev_scale_power == 1) {
        "10"
      } else {
        paste("10 ^", dev_scale_power)
      }
      y_lab <- parse(text = paste("deviance +", dev_scale_label))
    }
  }
  
  # get 95% credible intervals over plotting values (one row per rung)
  loglike_intervals <- t(apply(loglike, 2, quantile_95))
  
  # get data into ggplot format and define temperature colours
  df <- as.data.frame(loglike_intervals)
  df$x_vec <- x_vec
  df$col <- beta_vec
  
  # produce plot
  plot1 <- ggplot(df) + theme_bw() + theme(panel.grid.minor.x = element_blank(),
                                           panel.grid.major.x = element_blank())
  plot1 <- plot1 + geom_vline(aes_(xintercept = ~x_vec), col = grey(0.9))
  plot1 <- plot1 + geom_segment(aes_(x = ~x_vec, y = ~Q2.5, xend = ~x_vec, yend = ~Q97.5))
  plot1 <- plot1 + geom_point(aes_(x = ~x_vec, y = ~Q50, color = ~col))
  plot1 <- plot1 + xlab(x_lab) + ylab(y_lab)
  
  # colour by thermodynamic power; the legend is drawn only when requested
  plot1 <- plot1 + scale_colour_gradientn(colours = c("red", "blue"),
                                          name = "thermodynamic\npower",
                                          limits = c(0, 1),
                                          guide = if (legend) "colourbar" else "none")
  
  # define y-axis (log scale for deviance applies regardless of the legend)
  if (y_axis_type == 2) {
    y_min <- quantile(df$Q2.5, probs = 0.5)
    y_max <- max(df$Q97.5)
    plot1 <- plot1 + coord_cartesian(ylim = c(y_min, y_max))
  } else if (y_axis_type == 3) {
    plot1 <- plot1 + scale_y_continuous(trans = "log10")
  }
  
  # return plot object
  plot1
}

#------------------------------------------------
#' @title Plot acceptance rate between rungs
#'
#' @description For each pair of rungs, plot the acceptance rate between them. 
#'
#' @param project an RgeoProfile project, as produced by the function
#'   \code{rgeoprofile_project()}.
#' @param K which value of K to produce the plot for.
#' @param phase which phase to plot. Must be either "burnin" or "sampling".
#'
#' @return a ggplot object.
#'
#' @import ggplot2
#' @importFrom grDevices grey
#' @export

plot_coupling <- function(project, 
                          K = 1, 
                          phase = "sampling") {
  
  # check inputs
  assert_custom_class(project, "rgeoprofile_project")
  assert_single_pos_int(K, zero_allowed = FALSE)
  assert_in(phase, c("burnin", "sampling"))
  
  # get plotting data ("coupling_accept_burnin" or "coupling_accept_sampling")
  coupling <- get_output(project, paste0("coupling_accept_", phase), type = "summary", K = K)
  n_pairs <- length(coupling)
  
  # Rung boundaries sit at x = 1..(n_pairs + 1). Each acceptance rate is drawn
  # midway between the two rungs it couples.
  df_plot <- data.frame(x = seq_len(n_pairs) + 0.5, y = coupling)
  
  # produce plot
  plot1 <- ggplot(df_plot) + theme_bw() + theme(panel.grid.minor.x = element_blank(),
                                                panel.grid.major.x = element_blank())
  # constant intercepts outside aes(): one line per rung, not one per data row
  plot1 <- plot1 + geom_vline(xintercept = seq_len(n_pairs + 1), col = grey(0.9))
  plot1 <- plot1 + geom_point(aes_(x = ~x, y = ~y))
  plot1 <- plot1 + ylim(c(0, 1)) + xlim(c(1, n_pairs + 1))
  plot1 <- plot1 + xlab("rung") + ylab("coupling acceptance rate")
  
  # return plot object
  plot1
}

#------------------------------------------------
#' @title Posterior allocation plot
#'
#' @description Produce posterior allocation plot of current active set.
#'
#' @param project an RgeoProfile project, as produced by the function
#'   \code{rgeoprofile_project()}.
#' @param K which value(s) of K to produce the plot for. Defaults to every K
#'   with qmatrix output in the active set.
#' @param divide_ind_on whether to add dividing lines between bars.
#'
#' @return a ggplot object with one facet row per value of K.
#'
#' @import ggplot2
#' @export
#' 
#' @examples
#' \dontshow{p <- rgeoprofile_file("tutorial1_project.rds")}
#' # Plot the structure for a single K value.
#' plot_structure(project = p, K = 2)
#' # Similarly, plot the allocation structure for every K.
#' plot_structure(project = p, divide_ind_on = TRUE)

plot_structure <- function(project, 
                           K = NULL, 
                           divide_ind_on = FALSE) {

  # check inputs
  assert_custom_class(project, "rgeoprofile_project")
  if (!is.null(K)) {
    assert_pos_int(K)
  }
  assert_single_logical(divide_ind_on)

  # get active set and check non-zero
  s <- project$active_set
  if (s == 0) {
    stop("no active parameter set")
  }

  # set default K to all values with output
  single_K <- project$output$single_set[[s]]$single_K
  null_output <- vapply(single_K, function(x) is.null(x$summary$qmatrix), logical(1))
  if (all(null_output)) {
    stop("no output for active parameter set")
  }
  K <- define_default(K, which(!null_output))

  # check output exists for chosen K
  qmatrix_list <- lapply(K, function(k) single_K[[k]]$summary$qmatrix)
  for (i in seq_along(K)) {
    if (is.null(qmatrix_list[[i]])) {
      stop(sprintf("no qmatrix output for K = %s of active set", K[i]))
    }
  }

  # get data into ggplot format: one row per (individual, group) pair, with
  # rows whose first column is NA dropped
  df_list <- lapply(seq_along(K), function(i) {
    m <- unclass(qmatrix_list[[i]])
    m <- m[!is.na(m[, 1]), , drop = FALSE]
    n <- nrow(m)
    data.frame(K = as.numeric(K[i]), 
               ind = rep(seq_len(n), each = K[i]), 
               k = as.factor(rep(seq_len(K[i]), times = n)), 
               val = as.vector(t(m)))
  })
  df <- do.call(rbind, df_list)

  # number of individuals to divide (largest over all facets)
  n_ind <- max(0, df$ind)

  # produce basic plot
  plot1 <- ggplot(df) + theme_empty()
  plot1 <- plot1 + geom_bar(aes_(x = ~ind, y = ~val, fill = ~k), width = 1, stat = "identity")
  plot1 <- plot1 + scale_x_continuous(expand = c(0,0)) + scale_y_continuous(expand = c(0,0))
  plot1 <- plot1 + xlab("positive sentinel site")

  # arrange in rows
  if (length(K) == 1) {
    plot1 <- plot1 + facet_wrap(~K, ncol = 1)
    plot1 <- plot1 + theme(strip.background = element_blank(), strip.text = element_blank())
    plot1 <- plot1 + ylab("probability")
  } else {
    plot1 <- plot1 + facet_wrap(~K, ncol = 1, strip.position = "left")
    plot1 <- plot1 + theme(strip.background = element_blank())
    plot1 <- plot1 + ylab("K")
  }

  # add legends (guides(colour = FALSE) is deprecated in ggplot2; "none" is equivalent)
  plot1 <- plot1 + scale_fill_manual(values = more_colours(max(K)), name = "group")
  plot1 <- plot1 + scale_colour_manual(values = "white")
  plot1 <- plot1 + guides(colour = "none")

  # add border
  plot1 <- plot1 + theme(panel.border = element_rect(colour = "black", size = 2, fill = NA))

  # optionally add white dividing lines at the left edge of every bar
  if (divide_ind_on) {
    df_div <- data.frame(x = seq_len(n_ind) - 0.5, y = rep(0, n_ind))
    plot1 <- plot1 + geom_segment(aes_(x = ~x, y = ~y, xend = ~x, yend = ~y+1, col = "white"),
                                  size = 0.3, data = df_div)
  }

  plot1
}

#------------------------------------------------
#' @title Plot sigma 95\% credible intervals
#'
#' @description Plot credible intervals for a "single" (1) or "independent" (K)
#'  sigma model.
#'
#' @param project an RgeoProfile project, as produced by the function
#'   \code{rgeoprofile_project()}.
#' @param K which value of K to plot.
#'
#' @return a ggplot object.
#'
#' @import ggplot2
#' @export
#' 
#' @examples
#' \dontshow{p <- rgeoprofile_file("tutorial1_project.rds")}
#' plot_sigma(project = p)

plot_sigma <- function(project, 
                       K = NULL) {

  # check inputs
  assert_custom_class(project, "rgeoprofile_project")
  if (!is.null(K)) {
    assert_single_pos_int(K, zero_allowed = FALSE)
  }
  
  # get active set and check non-zero
  s <- project$active_set
  if (s == 0) {
    stop("no active parameter set")
  }
  
  # get sigma model
  sigma_model <- project$parameter_sets[[s]]$sigma_model
  
  # get output
  sigma_intervals <- get_output(project, "sigma_intervals", K)
  x_lab <- "source"
  if (sigma_model == "single") {
    # keep only rows that are not all zero, and drop the source label
    sigma_intervals <- sigma_intervals[which(rowSums(sigma_intervals) != 0),]
    rownames(sigma_intervals) <- ""
    x_lab <- paste0("K = ", K)
  }
  
  # Use the sources as an ordered factor. Plain character labels on a discrete
  # axis are sorted alphabetically ("1", "10", "2", ...) once K >= 10.
  source_names <- rownames(sigma_intervals)
  sigma_intervals$x_vec <- factor(source_names, levels = unique(source_names))

  # produce plot
  plot1 <- ggplot(sigma_intervals) + theme_bw()
  plot1 <- plot1 + geom_segment(aes_(x = ~x_vec, y = ~Q2.5, xend = ~x_vec, yend = ~Q97.5))
  plot1 <- plot1 + geom_point(aes_(x = ~x_vec, y = ~Q50))
  plot1 <- plot1 + scale_y_continuous(limits = c(0, max(sigma_intervals$Q97.5)*1.1), expand = c(0,0))
  plot1 <- plot1 + xlab(x_lab) + ylab("sigma")

  # return plot object
  plot1
}

#------------------------------------------------
#' @title Plot expected population size 95\% credible intervals
#'
#' @description Plot credible intervals for the entire expected population (single)
#'  or individual sources (independent).
#'
#' @param project an RgeoProfile project, as produced by the function
#'   \code{rgeoprofile_project()}.
#' @param K which value of K to produce the plot for.
#'
#' @return a ggplot object.
#'
#' @import ggplot2
#' @export
#' 
#' @examples
#' \dontshow{p <- rgeoprofile_file("tutorial1_project.rds")}
#' plot_expected_popsize(project = p, K = 1)

plot_expected_popsize <- function(project, 
                                  K = NULL) {

  # check inputs
  assert_custom_class(project, "rgeoprofile_project")
  assert_single_pos_int(K, zero_allowed = FALSE)
  
  # get active set and check non-zero
  s <- project$active_set
  if (s == 0) {
    stop("no active parameter set")
  }
  
  # single or independent expected population size?
  expected_popsize_model <- project$parameter_sets[[s]]$expected_popsize_model
      
  # get output
  expected_popsize_intervals <- get_output(project, "expected_popsize_intervals", K)
  
  # Set the x position of each interval. Independent sources are labelled
  # 1..K on the axis ticks; previously the labels were passed to xlab(),
  # which overprinted them as the axis title.
  if (expected_popsize_model == "single") {
    expected_popsize_intervals$x_vec <- rep(1, nrow(expected_popsize_intervals))
    x_lab <- ""
    x_breaks <- waiver()
  } else if (expected_popsize_model == "independent") {
    expected_popsize_intervals$x_vec <- seq_len(K)
    x_lab <- "source"
    x_breaks <- seq_len(K)
  } else {
    stop(sprintf("unrecognised expected_popsize_model: %s", expected_popsize_model))
  }
  
  # produce plot
  plot1 <- ggplot(expected_popsize_intervals) + theme_bw()
  plot1 <- plot1 + geom_segment(aes_(x = ~x_vec, y = ~Q2.5, xend = ~x_vec, yend = ~Q97.5))
  plot1 <- plot1 + geom_point(aes_(x = ~x_vec, y = ~Q50))
  plot1 <- plot1 + scale_x_continuous(breaks = x_breaks)
  plot1 <- plot1 + scale_y_continuous(limits = c(0, max(expected_popsize_intervals$Q97.5)*1.1), expand = c(0,0))
  plot1 <- plot1 + xlab(x_lab) + ylab("expected population size")

  # return plot object
  plot1
}

#------------------------------------------------
#' @title Plot alpha 95\% credible intervals
#'
#' @description Plot the over dispersion parameter alpha, in accordance with 
#'              a negative binomial model
#'
#' @param project an RgeoProfile project, as produced by the function
#'                \code{rgeoprofile_project()}.
#' @param K which value of K to produce the plot for. May be left as
#'          \code{NULL}, in which case it is passed to \code{get_output()}
#'          unchanged, as in \code{plot_trace()}.
#'
#' @return a ggplot object.
#'
#' @import ggplot2
#' @export
#' 
#' @examples 
#' # \dontshow{p <- rgeoprofile_file("tutorial1_project.rds")}
#' # plot_alpha(project = p)

plot_alpha <- function(project, 
                       K = NULL) {

  # check inputs (K is optional, matching the documented NULL default)
  assert_custom_class(project, "rgeoprofile_project")
  if (!is.null(K)) {
    assert_single_pos_int(K, zero_allowed = FALSE)
  }
      
  # get output
  alpha_intervals <- get_output(project, "alpha_intervals", K)
  
  # produce plot: a single interval drawn at a blank discrete x position
  plot1 <- ggplot(alpha_intervals) + theme_bw()
  plot1 <- plot1 + geom_segment(aes_(x = "", y = ~Q2.5, xend = "", yend = ~Q97.5))
  plot1 <- plot1 + geom_point(aes_(x = "", y = ~Q50))
  plot1 <- plot1 + scale_y_continuous(limits = c(0, max(alpha_intervals$Q97.5)*1.1), expand = c(0,0))
  plot1 <- plot1 + xlab("") + ylab("alpha") + ggtitle("Alpha - 95% credible interval")

  # return plot object
  plot1
}

#------------------------------------------------
#' @title Produce MCMC trace plot
#'
#' @description Produce MCMC trace plot of the log-likelihood at each iteration.
#'
#' @param project an RgeoProfile project, as produced by the function
#'   \code{rgeoprofile_project()}.
#' @param K which value of K to plot.
#' @param rung which rung to plot. Defaults to the cold chain (the last column
#'   of the log-likelihood output).
#' @param col colour of the trace.
#' @param phase plot the trace during the burnin or sampling phase
#'
#' @return a ggplot object.
#' 
#' @import ggplot2
#' @export
#' 
#' @examples
#' \dontshow{p <- rgeoprofile_file("tutorial1_project.rds")}
#' plot_trace(project = p, K = 2)
#' # Similarly, plot the trace for every K value.
#' plot_trace(project = p)

plot_trace <- function(project, 
                       K = NULL, 
                       rung = NULL, 
                       col = "black", 
                       phase = "sampling") {

  # check inputs
  assert_custom_class(project, "rgeoprofile_project")
  if (!is.null(K)) {
    assert_single_pos_int(K, zero_allowed = FALSE)
  }
  if (!is.null(rung)) {
    assert_single_pos_int(rung)
  }
  assert_in(phase, c("burnin", "sampling"))
  
  # get output ("loglike_burnin" or "loglike_sampling"); one column per rung
  loglike_mat <- get_output(project, paste0("loglike_", phase), K, "raw")
  
  # use cold rung (last column) by default
  rungs <- ncol(loglike_mat)
  rung <- define_default(rung, rungs)
  assert_leq(rung, rungs)
  loglike <- as.vector(loglike_mat[,rung])

  # get into ggplot format
  df <- data.frame(x = seq_along(loglike), y = loglike)

  # produce plot
  plot1 <- ggplot(df) + theme_bw() + ylab("log-likelihood")

  # complete plot; the constant colour mapping lets scale_colour_manual apply
  # `col`, and its legend is suppressed (guides(colour = FALSE) is deprecated)
  plot1 <- plot1 + geom_line(aes_(x = ~x, y = ~y, colour = "col1"))
  plot1 <- plot1 + coord_cartesian(xlim = c(0,nrow(df)))
  plot1 <- plot1 + scale_x_continuous(expand = c(0,0))
  plot1 <- plot1 + scale_colour_manual(values = col)
  plot1 <- plot1 + guides(colour = "none")
  plot1 <- plot1 + xlab("iteration")

  # return plot object
  plot1
}

#------------------------------------------------
#' @title Produce MCMC autocorrelation plot
#'
#' @description Produce MCMC autocorrelation plot of the log-likelihood
#'
#' @param project an RgeoProfile project, as produced by the function
#'   \code{rgeoprofile_project()}.
#' @param K which value of K to plot.
#' @param rung which rung to plot. Defaults to the cold chain (the last column
#'   of the log-likelihood output).
#' @param col colour of the autocorrelation bars.
#' @param phase plot the acf during the burnin or sampling phase
#'
#' @return a ggplot object.
#'
#' @import ggplot2
#' @export
#' 
#' @examples
#' \dontshow{p <- rgeoprofile_file("tutorial1_project.rds")}
#' plot_acf(project = p)

plot_acf <- function(project, 
                     K = NULL, 
                     rung = NULL, 
                     col = "black", 
                     phase = "sampling") {

  # check inputs
  assert_custom_class(project, "rgeoprofile_project")
  if (!is.null(K)) {
    assert_single_pos_int(K, zero_allowed = FALSE)
  }
  if (!is.null(rung)) {
    assert_single_pos_int(rung)
  }
  assert_in(phase, c("burnin", "sampling"))

  # get output ("loglike_burnin" or "loglike_sampling"); one column per rung
  loglike_mat <- get_output(project, paste0("loglike_", phase), K, "raw")
  
  # use cold rung (last column) by default
  rungs <- ncol(loglike_mat)
  rung <- define_default(rung, rungs)
  assert_leq(rung, rungs)
  v <- as.vector(loglike_mat[,rung])

  # Choose the maximum lag: 3x the ratio of chain length to effective sample
  # size, at least 20 and at most the chain length.
  # NOTE(review): effectiveSize is presumably coda::effectiveSize, but this
  # block has no @importFrom for it; confirm it is imported in NAMESPACE.
  lag_max <- round(3*length(v)/effectiveSize(v))
  lag_max <- max(lag_max, 20)
  lag_max <- min(lag_max, length(v))

  # get into ggplot format (acf output starts at lag 0)
  a <- acf(v, lag.max = lag_max, plot = FALSE)
  acf_vals <- as.vector(a$acf)
  df <- data.frame(lag = seq_along(acf_vals) - 1, ACF = acf_vals)

  # produce plot; the legend for the constant colour mapping is suppressed
  # (guides(colour = FALSE) is deprecated in ggplot2)
  plot1 <- ggplot(df) + theme_bw()
  plot1 <- plot1 + geom_segment(aes_(x = ~lag, y = 0, xend = ~lag, yend = ~ACF, colour = "col1"))
  plot1 <- plot1 + scale_colour_manual(values = col)
  plot1 <- plot1 + guides(colour = "none")
  plot1 <- plot1 + xlab("lag") + ylab("ACF")

  # return plot object
  plot1
}

#------------------------------------------------
#' @title Produce MCMC density plot
#'
#' @description Produce MCMC density plot of the log-likelihood
#'
#' @param project an RgeoProfile project, as produced by the function
#'   \code{rgeoprofile_project()}.
#' @param K value of K to plot.
#' @param rung which rung to plot. Defaults to the cold chain (the last column
#'   of the log-likelihood output).
#' @param col fill colour of the histogram bars.
#' @param phase plot the density during the burnin or sampling phase.
#'
#' @return a ggplot object.
#'
#' @import ggplot2
#' @export
#' 
#' @examples
#' \dontshow{p <- rgeoprofile_file("tutorial1_project.rds")}
#' plot_density(project = p)

plot_density <- function(project, 
                         K = NULL, 
                         rung = NULL, 
                         col = "black", 
                         phase = "sampling") {

  # check inputs
  assert_custom_class(project, "rgeoprofile_project")
  if (!is.null(K)) {
    assert_single_pos_int(K, zero_allowed = FALSE)
  }
  if (!is.null(rung)) {
    assert_single_pos_int(rung)
  }
  assert_in(phase, c("burnin", "sampling"))

  # get output ("loglike_burnin" or "loglike_sampling"); one column per rung
  loglike_mat <- get_output(project, paste0("loglike_", phase), K, "raw")

  # use cold rung (last column) by default
  rungs <- ncol(loglike_mat)
  rung <- define_default(rung, rungs)
  assert_leq(rung, rungs)
  loglike <- as.vector(loglike_mat[,rung])

  # get into ggplot format
  df <- data.frame(v = loglike)

  # produce a 50-bin density-scaled histogram; the legend for the constant fill
  # mapping is suppressed (guides(fill = FALSE) is deprecated in ggplot2)
  plot1 <- ggplot(df) + theme_bw() + xlab("log-likelihood")
  plot1 <- plot1 + geom_histogram(aes_(x = ~v, y = ~..density.., fill = "col1"), bins = 50)
  plot1 <- plot1 + scale_fill_manual(values = col)
  plot1 <- plot1 + guides(fill = "none")
  plot1 <- plot1 + ylab("density")

  # return plot object
  plot1
}

#------------------------------------------------
#' @title Produce diagnostic plots of log-likelihood
#'
#' @description Produce diagnostic plots of the log-likelihood.
#'
#' @details For a value of K, produce the trace, auto-correlation and density
#'   plots of the MCMC. They are arranged with the trace on the top row and the
#'   other two below it.
#'
#' @param project an RgeoProfile project, as produced by the function
#'   \code{rgeoprofile_project()}.
#' @param K which value of K to plot. Defaults to the first K with
#'   log-likelihood output for the chosen phase.
#' @param rung which rung to plot. Defaults to the cold chain (the last column
#'   of the log-likelihood output).
#' @param phase check the burnin or sampling phase of the MCMC chain
#' @param col colour of the trace.
#'
#' @return invisibly, the object returned by \code{gridExtra::grid.arrange()}.
#'
#' @import ggplot2
#' @importFrom gridExtra grid.arrange
#' @export
#' 
#' @examples
#' \dontshow{p <- rgeoprofile_file("tutorial1_project.rds")}
#' plot_loglike_diagnostic(project = p, K = 2)

plot_loglike_diagnostic <- function(project, 
                                    K = NULL, 
                                    rung = NULL, 
                                    phase = "sampling", 
                                    col = "black") {
  
  # check inputs
  assert_custom_class(project, "rgeoprofile_project")
  if (!is.null(K)) {
    assert_single_pos_int(K, zero_allowed = FALSE)
  }
  if (!is.null(rung)) {
    assert_single_pos_int(rung)
  }
  assert_in(phase, c("sampling", "burnin"))
  
  # get active set and check non-zero
  s <- project$active_set
  if (s == 0) {
    stop("no active parameter set")
  }
  
  # Set the default K to the first value with output for the requested phase.
  # Previously sampling output was checked even when phase = "burnin".
  mcmc_phase <- paste0("loglike_", phase)
  null_output <- vapply(project$output$single_set[[s]]$single_K,
                        function(x) is.null(x$raw[[mcmc_phase]]),
                        logical(1))
  if (all(null_output)) {
    stop(sprintf("no %s output for active parameter set", mcmc_phase))
  }
  if (is.null(K)) {
    K <- which(!null_output)[1]
    message(sprintf("using K = %s by default", K))
  }
  
  # get output (only needed here to validate the rung)
  loglike_mat <- get_output(project, mcmc_phase, K, "raw")
  
  # use cold rung (last column) by default
  rungs <- ncol(loglike_mat)
  rung <- define_default(rung, rungs)
  assert_leq(rung, rungs)
  
  # produce individual diagnostic plots and add features
  plot1 <- plot_trace(project, K = K, rung = rung, col = col, phase = phase)
  plot1 <- plot1 + ggtitle("MCMC trace")
  
  plot2 <- plot_acf(project, K = K, rung = rung, col = col, phase = phase)
  plot2 <- plot2 + ggtitle("autocorrelation")
  
  plot3 <- plot_density(project, K = K, rung = rung, col = col, phase = phase)
  plot3 <- plot3 + ggtitle("density")
  
  # produce grid of plots: trace across the top, acf and density below
  ret <- gridExtra::grid.arrange(plot1, plot2, plot3, layout_matrix = rbind(c(1,1), c(2,3)))
  invisible(ret)
}

#------------------------------------------------
#' @title Create dynamic map
#'
#' @description Create dynamic map
#'
#' @param map_type an index into \code{leaflet::providers} selecting the base
#'   map. Valid values run from 1 to \code{length(leaflet::providers)}, which
#'   depends on the installed leaflet version, so the provider at a given index
#'   can also differ between versions. Defaults to 110, which was intended to
#'   select a "CartoDB" style.
#'
#' @return a leaflet map object.
#'
#' @import leaflet
#' @export
#' 
#' @examples
#' # Standard OSM format is given by map_type = 1, though this value can
#' # take anywhere between 1 and length(leaflet::providers).
#' plot_map(map_type = 1)

plot_map <- function(map_type = 110) {

  # check inputs against the providers actually available, rather than a
  # hard-coded count that drifts with leaflet releases
  n_providers <- length(leaflet::providers)
  assert_in(map_type, seq_len(n_providers),
            message = sprintf("map_type must be in 1:%s", n_providers))

  # produce plot
  myplot <- leaflet()
  myplot <- addProviderTiles(myplot, leaflet::providers[[map_type]])

  # return plot object
  myplot
}

#------------------------------------------------
#' @title Plot DIC over all K
#'
#' @description Plot DIC over all K for the current active parameter set.
#'
#' @param project an RgeoProfile project, as produced by the function
#'   \code{rgeoprofile_project()}.
#'
#' @return a ggplot object.
#'
#' @import ggplot2
#' @export
#' 
#' @examples
#' \dontshow{p <- rgeoprofile_file("tutorial1_project.rds")}
#' plot_DIC_gelman(p)

plot_DIC_gelman <- function(project) {

  # check inputs
  assert_custom_class(project, "rgeoprofile_project")

  # get active set and check non-zero
  s <- project$active_set
  if (s == 0) {
    stop("no active parameter set")
  }

  # get DIC values (a data frame with columns K and DIC_gelman), failing with
  # a clear message instead of an obscure ggplot error when none exist
  df <- project$output$single_set[[s]]$all_K$DIC_gelman
  if (is.null(df)) {
    stop("no DIC_gelman output for active parameter set")
  }

  # produce plot
  plot1 <- ggplot(data = df) + theme_bw()
  plot1 <- plot1 + geom_point(aes_(x = ~as.factor(K), y = ~DIC_gelman))
  plot1 <- plot1 + xlab("K") + ylab("DIC (Gelman)")

  plot1
}

#------------------------------------------------
#' @title Produce Lorenz plot of hitscores
#'
#' @description Produce Lorenz plot of hitscores
#'
#' @details The hs object is obtained from the \code{get_hitscores()} function.
#'   Any longitude/latitude columns are dropped. Each remaining column is
#'   treated as one search method. For each method, sources are ordered by
#'   hitscore, and the cumulative percentage found is plotted against search
#'   effort. Sources with an \code{NA} hitscore are left out of the curve but
#'   still count towards the total.
#'
#' @param hs dataframe of hitscores.
#' @param col vector of group colours, in the same order as the hitscore
#'   columns. Uses \code{more_colours()} by default.
#' @param counts optional vector of counts corresponding to each source (row
#'   of \code{hs}). If specified, the y-axis is in terms of total counts found,
#'   rather than total sources found.
#'
#' @return a ggplot object.
#'
#' @import ggplot2
#' @export
#' 
#' @examples
#' \dontshow{hs <- rgeoprofile_file("tutorial1_hitscore.rds")}
#' plot_lorenz(hs)

plot_lorenz <- function(hs, 
                        col = NULL, 
                        counts = NULL) {

  # check inputs
  assert_dataframe(hs)
  if (!is.null(counts)) {
    assert_pos_int(counts, zero_allowed = FALSE)
    assert_vector(counts)
    assert_length(counts, nrow(hs))
  }

  # drop lon/lat columns
  hs <- hs[ , !names(hs) %in% c("longitude", "latitude"), drop = FALSE]

  # get properties
  ns <- nrow(hs)
  hs_names <- colnames(hs)

  # set default colours
  col <- define_default(col, more_colours(ncol(hs)))

  # per search method, the order in which sources are found (NA hitscores dropped)
  ord_list <- lapply(hs, order, na.last = NA)

  # sorted hitscores on x-axis, padded with the 0% and 100% end points
  x_list <- Map(function(x, o) c(0, x[o], 100), hs, ord_list)

  # Get the cumulative percentage found on the y-axis. Counts are accumulated
  # in the same hitscore order as x; they were previously summed in row order.
  if (is.null(counts)) {
    y_list <- lapply(ord_list, function(o) {
      n_found <- length(o)
      c(0:n_found, n_found)/ns*100
    })
  } else {
    y_list <- lapply(ord_list, function(o) {
      c(0, cumsum(counts[o]), sum(counts[o]))/sum(counts)*100
    })
  }

  # search-method label for every point
  z_list <- rep(hs_names, times = lengths(x_list))

  # Get into ggplot format. Factor levels follow column order so `col` is
  # matched to the methods in the order given, not alphabetically.
  df <- data.frame(x = unlist(x_list),
                   y = unlist(y_list),
                   col = factor(z_list, levels = unique(hs_names)))

  # make ggplot object
  plot1 <- ggplot(data = df, aes_(x = ~x, y = ~y, color = ~col)) + theme_bw()
  plot1 <- plot1 + geom_point() + geom_line()
  plot1 <- plot1 + geom_abline(slope = 1, linetype = "dashed")
  plot1 <- plot1 + scale_colour_manual(values = col, name = "Search method")
  plot1 <- plot1 + scale_x_continuous(limits = c(0,100)) + scale_y_continuous(limits = c(0,100))
  plot1 <- plot1 + xlab("search effort (%)") + ylab("found (%)")

  plot1
}

#------------------------------------------------
#' @title Add sentinel sites to dynamic map
#'
#' @description Add sentinel sites to dynamic map
#'
#' @param myplot dynamic map produced by \code{plot_map()} function
#' @param project an RgeoProfile project, as produced by the function
#'   \code{rgeoprofile_project()}.
#' @param sentinel_radius the radius of sentinel sites. Taken from the active
#'   parameter set if unspecified.
#' @param fill whether to fill circles.
#' @param fill_colour colour of circle fill.
#' @param fill_opacity fill opacity.
#' @param border whether to add border to circles.
#' @param border_colour colour of circle borders.
#' @param border_weight thickness of circle borders.
#' @param border_opacity opacity of circle borders.
#' @param legend whether to add a legend for site count.
#' @param label whether to label sentinel sites with densities.
#' @param label_size size of the label.
#'
#' @import leaflet
#' @importFrom grDevices grey
#' @export
#' 
#' @examples
#' \dontshow{p <- rgeoprofile_file("tutorial1_project.rds")}
#' plot1 <- plot_map()
#' plot1 <- overlay_sentinels(plot1, project = p, fill_opacity = 0.9, fill = TRUE,
#'                            fill_colour = c(grey(0.7), "red"), border = c(FALSE, TRUE),
#'                            border_colour = "black",border_weight = 0.5)
#' plot1

overlay_sentinels <- function(myplot,
                              project,
                              sentinel_radius = NULL,
                              fill = TRUE,
                              fill_colour = c(grey(0.5), "red"),
                              fill_opacity = 0.5,
                              border = FALSE,
                              border_colour = "black",
                              border_weight = 1,
                              border_opacity = 1.0,
                              legend = FALSE,
                              label = FALSE,
                              label_size = 15) {
  
  # check inputs. fill, fill_colour, border and border_colour may each be a
  # single value (applied to every site) or a pair, in which case the first
  # element is used for sites with zero counts and the second for sites with
  # positive counts
  assert_custom_class(myplot, "leaflet")
  assert_custom_class(project, "rgeoprofile_project")
  if (!is.null(sentinel_radius)) {
    assert_single_pos(sentinel_radius)
  }
  assert_logical(fill)
  assert_vector(fill)
  assert_in(length(fill), c(1,2))
  if (length(fill) == 1) {
    fill <- rep(fill, 2)
  }
  assert_string(fill_colour)
  assert_vector(fill_colour)
  # assert_in(length(fill_colour), c(1,2))
  if (length(fill_colour) == 1) {
    fill_colour <- rep(fill_colour, 2)
  }
  assert_single_pos(fill_opacity)
  assert_bounded(fill_opacity, 0, 1, inclusive_left = TRUE, inclusive_right = TRUE)
  assert_logical(border)
  assert_vector(border)
  assert_in(length(border), c(1,2))
  if (length(border) == 1) {
    border <- rep(border, 2)
  }
  assert_string(border_colour)
  assert_vector(border_colour)
  # assert_in(length(border_colour), c(1,2))
  if (length(border_colour) == 1) {
    border_colour <- rep(border_colour, 2)
  }
  assert_single_pos(border_opacity)
  assert_bounded(border_opacity, 0, 1, inclusive_left = TRUE, inclusive_right = TRUE)
  assert_logical(legend)
  assert_logical(label)
  assert_single_pos(label_size)
  
  # check for data
  df <- project$data$frame
  if (is.null(df)) {
    stop("no data loaded")
  }
  
  # get sentinel radius (in km) from active parameter set by default
  if (is.null(sentinel_radius)) {
    message("getting sentinel radius from active parameter set:")
    
    # get active set and check non-zero
    s <- project$active_set
    if (s == 0) {
      stop("no active parameter set")
    }
    
    sentinel_radius <- project$parameter_sets[[s]]$sentinel_radius
    message(sprintf("  sentinel radius = %skm", sentinel_radius))
  }
  
  # make circle attributes depend on counts: sites with a positive count take
  # the second element of each attribute pair, all other sites the first
  n <- nrow(df)
  is_positive <- df$counts > 0
  
  fill_vec <- rep(fill[1], n)
  fill_vec[is_positive] <- fill[2]
  fill_colour_vec <- rep(fill_colour[1], n)
  fill_colour_vec[is_positive] <- fill_colour[2]
  
  border_vec <- rep(border[1], n)
  border_vec[is_positive] <- border[2]
  border_colour_vec <- rep(border_colour[1], n)
  border_colour_vec[is_positive] <- border_colour[2]
  
  # add legend for sentinel site counts. Including 1 in the max() guards
  # against data with no positive counts (max of an empty vector is -Inf,
  # which makes seq() fail); na.rm drops missing counts
  if (isTRUE(legend)) {
    max_count <- max(c(1, df$counts), na.rm = TRUE)
    count_seq <- seq(0, max_count, 1)
    # NOTE(review): legend palette is built from border_colour, not
    # fill_colour - confirm this is intended
    pal <- colorNumeric(palette = border_colour, domain = count_seq)
    myplot <- addLegend(myplot, "topright", pal, values = count_seq,
                        title = "Site Count", opacity = 1)
  }
  
  # overlay circles (leaflet circle radius is in metres, radius is given in km)
  myplot <- addCircles(myplot, lng = df$longitude, lat = df$latitude,
                       radius = sentinel_radius*1e3,
                       fill = fill_vec, fillColor = fill_colour_vec, fillOpacity = fill_opacity,
                       stroke = border_vec, color = border_colour_vec,
                       opacity = border_opacity, weight = border_weight)
  
  # label sentinel sites with their counts
  if (isTRUE(label)) {
    myplot <- addLabelOnlyMarkers(myplot, lng = df$longitude, lat = df$latitude,
                                  label = as.character(df$counts),
                                  labelOptions = labelOptions(noHide = TRUE, textOnly = TRUE,
                                                              direction = "center",
                                                              textsize = paste0(label_size, "px")))
  }
  
  # return plot object
  return(myplot)
}

#------------------------------------------------
#' @title Add trial sites to dynamic map
#'
#' @description Add trial sites to dynamic map, either as circles (coloured by
#'   whether any positives were observed) or as pie charts of positive vs
#'   negative results.
#'
#' @param myplot dynamic map produced by \code{plot_map()} function
#' @param project an RgeoProfile project, as produced by the function
#'   \code{rgeoprofile_project()}.
#' @param fill whether to fill circles. Either a single value, or a pair where
#'   the first element applies to sites with no positives and the second to
#'   sites with at least one positive.
#' @param fill_colour colour of circle fill (single value or pair, as for
#'   \code{fill}).
#' @param fill_opacity fill opacity.
#' @param border whether to add border to circles (single value or pair, as
#'   for \code{fill}).
#' @param border_colour colour of circle borders (single value or pair, as for
#'   \code{fill}).
#' @param border_weight thickness of circle borders.
#' @param border_opacity opacity of circle borders.
#' @param legend whether to add a legend for site count. Currently validated
#'   but not used.
#' @param site_radius size of each site. For \code{plot_type = "circles"} this
#'   is the circle radius in km; for \code{plot_type = "piecharts"} it is the
#'   pie chart width in pixels.
#' @param plot_type plot trial sites as \code{"circles"} or \code{"piecharts"}.
#'   The fill and border arguments apply only to circles.
#'
#' @import leaflet
#' @importFrom grDevices grey
#' @export
#' 
#' @examples
#' #\dontshow{p <- rgeoprofile_file("tutorial2_project.rds")}
#' #plot1 <- plot_map()
#' #plot1 <- overlay_trial_sites(plot1, project = p, fill_opacity = 0.9, fill = TRUE,
#' #                           fill_colour = c(grey(0.7), "red"), border = c(FALSE, TRUE),
#' #                           border_colour = "black",border_weight = 0.5)
#' #plot1

overlay_trial_sites <- function(myplot,
                                project,
                                fill = TRUE,
                                fill_colour = c(grey(0.5), "red"),
                                fill_opacity = 0.5,
                                border = FALSE,
                                border_colour = "black",
                                border_weight = 1,
                                border_opacity = 1.0,
                                legend = FALSE,
                                site_radius = 20,
                                plot_type = "piecharts") {
  
  # check inputs
  assert_custom_class(myplot, "leaflet")
  assert_custom_class(project, "rgeoprofile_project")
  assert_logical(fill)
  assert_vector(fill)
  assert_in(length(fill), c(1,2))
  if (length(fill) == 1) {
    fill <- rep(fill, 2)
  }
  assert_string(fill_colour)
  assert_vector(fill_colour)
  # assert_in(length(fill_colour), c(1,2))
  if (length(fill_colour) == 1) {
    fill_colour <- rep(fill_colour, 2)
  }
  assert_single_pos(fill_opacity)
  assert_bounded(fill_opacity, 0, 1, inclusive_left = TRUE, inclusive_right = TRUE)
  assert_logical(border)
  assert_vector(border)
  assert_in(length(border), c(1,2))
  if (length(border) == 1) {
    border <- rep(border, 2)
  }
  assert_string(border_colour)
  assert_vector(border_colour)
  # assert_in(length(border_colour), c(1,2))
  if (length(border_colour) == 1) {
    border_colour <- rep(border_colour, 2)
  }
  assert_single_pos(border_opacity)
  assert_bounded(border_opacity, 0, 1, inclusive_left = TRUE, inclusive_right = TRUE)
  assert_logical(legend)
  assert_single_pos_int(site_radius)
  assert_in(plot_type, c("circles", "piecharts"))
  
  # check for data
  df <- project$data$frame
  if (is.null(df)) {
    stop("no data loaded")
  }
  
  if (plot_type == "circles") {
    
    # make circle attributes depend on whether any positives were observed:
    # second element of each pair for positive sites, first element otherwise
    n <- nrow(df)
    is_positive <- df$positive > 0
    
    fill_vec <- rep(fill[1], n)
    fill_vec[is_positive] <- fill[2]
    fill_colour_vec <- rep(fill_colour[1], n)
    fill_colour_vec[is_positive] <- fill_colour[2]
    
    border_vec <- rep(border[1], n)
    border_vec[is_positive] <- border[2]
    border_colour_vec <- rep(border_colour[1], n)
    border_colour_vec[is_positive] <- border_colour[2]
    
    # overlay circles. leaflet circle radii are in metres, while site_radius
    # is given in km (matching overlay_sentinels)
    myplot <- addCircles(myplot,
                         lng = df$longitude,
                         lat = df$latitude,
                         radius = site_radius*1e3,
                         fill = fill_vec,
                         fillColor = fill_colour_vec,
                         fillOpacity = fill_opacity,
                         stroke = border_vec,
                         color = border_colour_vec,
                         opacity = border_opacity,
                         weight = border_weight)
    
  } else if (plot_type == "piecharts") {
    
    # one row per site, columns = counts of positive and negative tests
    chart_df <- data.frame(positive = df$positive,
                           negative = df$tested - df$positive)
    
    # overlay pie charts; width is in pixels
    myplot <- addMinicharts(myplot, df$longitude, df$latitude,
                            type = "pie",
                            chartdata = chart_df,
                            colorPalette = c("red", "grey"),
                            width = site_radius,
                            transitionTime = 20)
  }
  
  return(myplot)
}


#------------------------------------------------
#' @title Add points to dynamic map
#'
#' @description Add points to dynamic map
#'
#' @param myplot dynamic map produced by \code{plot_map()} function.
#' @param lon,lat longitude and latitude of points.
#' @param col colour of points.
#' @param size size of points.
#' @param opacity opacity of points.
#'
#' @import leaflet
#' @export
#' 
#' @examples
#' \dontshow{mysim <- rgeoprofile_file("tutorial1_mysim.rds")}
#' all_records <- mysim$record$data_all
#' plot1 <- plot_map()
#' plot1 <- overlay_points(myplot = plot1, all_records$longitude, all_records$latitude)
#' show(plot1)

overlay_points <- function(myplot, 
                           lon, 
                           lat, 
                           col = "black", 
                           size = 1, 
                           opacity = 1.0) {

  # check inputs
  assert_custom_class(myplot, "leaflet")
  assert_numeric(lon)
  assert_vector(lon)
  assert_numeric(lat)
  assert_vector(lat)
  assert_same_length(lon, lat)
  assert_single_pos(size, zero_allowed = FALSE)
  assert_single_pos(opacity, zero_allowed = TRUE)
  # explicit inclusive [0, 1] bounds, consistent with the other overlay_*
  # functions and with zero opacity being allowed above
  assert_bounded(opacity, left = 0, right = 1, inclusive_left = TRUE, inclusive_right = TRUE)

  # add borderless circle markers; size is passed to leaflet as the marker
  # radius
  myplot <- addCircleMarkers(myplot, lng = lon, lat = lat, radius = size,
                             fillColor = col, stroke = FALSE, fillOpacity = opacity)

  # return plot object
  return(myplot)
}

#------------------------------------------------
#' @title Add spatial prior to dynamic map
#'
#' @description Add spatial prior to dynamic map
#'
#' @param myplot dynamic map produced by \code{plot_map()} function.
#' @param project an RgeoProfile project, as produced by the function
#'   \code{rgeoprofile_project()}.
#' @param col set of plotting colours.
#' @param opacity opacity of spatial prior.
#' @param smoothing what level of smoothing to apply to spatial prior. Smoothing
#'   is applied using the \code{raster} function \code{disaggregate}, with
#'   \code{method = "bilinear"}.
#'
#' @import leaflet
#' @importFrom raster disaggregate
#' @export
#' 
#' @examples
#' \dontshow{library(silverblaze)}
#' \dontshow{p <- rgeoprofile_file("tutorial1_project.rds")}
#' plot1 <- plot_map()
#' plot1 <- overlay_spatial_prior(myplot = plot1, project = p)
#' plot1

overlay_spatial_prior <- function(myplot,
                                  project,
                                  col = col_hotcold(),
                                  opacity = 0.8,
                                  smoothing = 1) {

  # validate arguments
  assert_custom_class(myplot, "leaflet")
  assert_custom_class(project, "rgeoprofile_project")
  assert_string(col)
  assert_bounded(opacity, left = 0, right = 1, inclusive_left = TRUE, inclusive_right = TRUE)
  assert_single_pos(smoothing)
  assert_greq(smoothing, 1.0)

  # the prior is stored on the active parameter set; an index of 0 means no
  # set is active
  active_idx <- project$active_set
  if (active_idx == 0) {
    stop("no active parameter set")
  }
  prior_raster <- project$parameter_sets[[active_idx]]$spatial_prior

  # optionally interpolate onto a finer grid before drawing
  if (smoothing > 1.0) {
    prior_raster <- raster::disaggregate(prior_raster, smoothing, method = "bilinear")
  }

  # draw the raster; the updated map is the return value
  leaflet::addRasterImage(myplot, x = prior_raster, colors = col, opacity = opacity)
}

#------------------------------------------------
#' @title Add geoprofile to dynamic map
#'
#' @description Add geoprofile to dynamic map
#'
#' @param myplot dynamic map produced by \code{plot_map()} function.
#' @param project an RgeoProfile project, as produced by the function
#'   \code{rgeoprofile_project()}.
#' @param K which value of K to plot.
#' @param source which source to plot. If NULL then plot combined surface.
#' @param realised if TRUE then plot surface for realised sources only.
#' @param threshold what proportion of geoprofile to plot.
#' @param col set of plotting colours.
#' @param opacity opacity of geoprofile (that is not invisible due to being
#'   below threshold).
#' @param smoothing what level of smoothing to apply to geoprofile. Smoothing is
#'   applied using the \code{raster} function \code{disaggregate}, with
#'   \code{method = "bilinear"}.
#' @param legend Set to TRUE or FALSE, this will add a hitscore legend to the
#'   plot.
#'
#' @import leaflet
#' @importFrom grDevices grey
#' @export
#' 
#' @examples
#' \dontshow{p <- rgeoprofile_file("tutorial1_project.rds")}
#' plot1 <- plot_map()
#' plot1 <-overlay_geoprofile(myplot = plot1, project = p, K = 2, source = NULL, threshold = 0.5)
#' plot1

overlay_geoprofile <- function(myplot,
                               project,
                               K = NULL,
                               source = NULL,
                               realised = FALSE,
                               threshold = 0.1,
                               col = col_hotcold(),
                               opacity = 0.8,
                               smoothing = 1,
                               legend  = FALSE) {
  
  # --- argument validation ---
  assert_custom_class(myplot, "leaflet")
  assert_custom_class(project, "rgeoprofile_project")
  if (!is.null(source)) {
    assert_single_pos_int(source, zero_allowed = FALSE)
  }
  assert_single_logical(realised)
  assert_bounded(threshold, left = 0, right = 1, inclusive_left = TRUE, inclusive_right = TRUE)
  assert_string(col)
  assert_bounded(opacity, left = 0, right = 1, inclusive_left = TRUE, inclusive_right = TRUE)
  assert_single_pos(smoothing)
  assert_greq(smoothing, 1.0)
  assert_single_logical(legend)
  
  # --- choose which geoprofile raster to draw ---
  # realised surface takes priority; otherwise the combined surface, or one
  # element of the per-source split when a source is requested
  if (realised) {
    gp <- get_output(project, "geoprofile_realised", K = K)
  } else if (is.null(source)) {
    gp <- get_output(project, "geoprofile", K = K)
  } else {
    assert_leq(source, K)
    gp <- get_output(project, "geoprofile_split", K = K)[[source]]
  }
  
  if (smoothing > 1.0) {
    gp <- disaggregate(gp, smoothing, method = "bilinear")
  }
  
  # hide cells above threshold*100 (values are compared on a 0-100 scale,
  # presumably hitscore percentages)
  gp_values <- values(gp)
  gp_values[gp_values > threshold*100] <- NA
  gp <- setValues(gp, gp_values)
  
  # --- draw surface and its bounding box ---
  myplot <- addRasterImage(myplot, x = gp, colors = col, opacity = opacity, project = FALSE)
  myplot <- addRectangles(myplot, xmin(gp), ymin(gp), xmax(gp), ymax(gp),
                          fill = FALSE, weight = 2, color = grey(0.2))
  
  # optional hitscore legend spanning 0..threshold
  if (legend) {
    legend_seq <- seq(0, threshold, threshold / (length(col) - 1))
    legend_pal <- colorNumeric(palette = col, domain = legend_seq)
    myplot <- addLegend(myplot, "bottomright", pal = legend_pal, values = legend_seq,
                        title = "Hit score", opacity = 1)
  }
  
  myplot
}

#------------------------------------------------
#' @title Add posterior probability surface to dynamic map
#'
#' @description Add posterior probability surface to dynamic map
#'
#' @param myplot dynamic map produced by \code{plot_map()} function.
#' @param project an RgeoProfile project, as produced by the function
#'   \code{rgeoprofile_project()}.
#' @param K which value of K to plot.
#' @param source which source to plot. If NULL then plot combined surface.
#' @param realised if TRUE then plot surface for realised sources only.
#' @param threshold what proportion of posterior probability surface to plot.
#' @param col set of plotting colours.
#' @param opacity opacity of posterior probability surface (that is not
#'   invisible due to being below threshold).
#' @param smoothing what level of smoothing to apply to posterior probability
#'   surface. Smoothing is applied using the \code{raster} function
#'   \code{disaggregate}, with \code{method = "bilinear"}.
#' @param legend whether or not a legend is plotted
#'
#' @import leaflet
#' @importFrom grDevices grey
#' @export
#' 
#' @examples
#' \dontshow{p <- rgeoprofile_file("tutorial1_project.rds")}
#' plot1 <- plot_map()
#' plot1 <-overlay_surface(myplot = plot1, project = p, K = 2, source = NULL, threshold = 0.5)
#' plot1

overlay_surface <- function(myplot,
                            project,
                            K = NULL,
                            source = NULL,
                            realised = FALSE,
                            threshold = 0.1,
                            col = rev(col_hotcold()),
                            opacity = 0.8,
                            smoothing = 1.0,
                            legend = FALSE) {
  
  # check inputs
  assert_custom_class(myplot, "leaflet")
  assert_custom_class(project, "rgeoprofile_project")
  if (!is.null(source)) {
    assert_single_pos_int(source, zero_allowed = FALSE)
  }
  assert_single_logical(realised)
  assert_bounded(threshold, left = 0, right = 1, inclusive_left = TRUE, inclusive_right = TRUE)
  assert_string(col)
  assert_bounded(opacity, left = 0, right = 1, inclusive_left = TRUE, inclusive_right = TRUE)
  assert_single_pos(smoothing)
  assert_greq(smoothing, 1.0)
  assert_single_logical(legend)
  
  # extract posterior probability surface
  if (realised) {
    prob_surface <- get_output(project, "prob_surface_realised", K = K)
  } else {
    if (is.null(source)) {
      prob_surface <- get_output(project, "prob_surface", K = K)
    } else {
      assert_leq(source, K)
      prob_surface_split <- get_output(project, "prob_surface_split", K = K)
      # surface for the requested source only
      prob_surface <- prob_surface_split[[source]]
    }
  }
  
  # apply smoothing
  if (smoothing > 1.0) {
    prob_surface <- disaggregate(prob_surface, smoothing, method = "bilinear")
  }
  
  # apply threshold: rank cells from most to least probable and keep the
  # smallest top set whose cumulative probability reaches threshold (the cell
  # that crosses threshold is kept). All remaining cells become NA, i.e.
  # transparent. order() with na.last = NA skips cells that are already NA.
  prob_values <- values(prob_surface)
  ord <- order(prob_values, decreasing = TRUE, na.last = NA)
  sorted_values <- prob_values[ord]
  mass_before <- cumsum(sorted_values) - sorted_values
  prob_values[ord[mass_before >= threshold]] <- NA
  prob_surface <- setValues(prob_surface, prob_values)
  
  # overlay raster
  myplot <- addRasterImage(myplot, x = prob_surface, colors = col, opacity = opacity)
  
  # add bounding rect
  myplot <- addRectangles(myplot, xmin(prob_surface), ymin(prob_surface),
                          xmax(prob_surface), ymax(prob_surface),
                          fill = FALSE, weight = 2, color = grey(0.2))
  
  # add legend
  if (legend) {
    prob_sequence <- seq(1 - threshold, 1, threshold/(length(col) - 1))
    pal <- colorNumeric(palette = col, domain = prob_sequence)
    myplot <- addLegend(myplot, "bottomright", pal = pal, values = prob_sequence,
                        title = "Posterior\nprobability", opacity = 1)
  }
  
  # return plot object
  return(myplot)
}

#------------------------------------------------
#' @title Add risk surface to dynamic map
#'
#' @description Add a risk surface to a dynamic map. For each of
#'   \code{iterations} randomly chosen posterior draws, the sampled source
#'   locations, sigmas and expected population sizes are combined into a
#'   hazard at every cell of the spatial prior grid, the hazard is transformed
#'   to a probability via \code{hazard/(hazard + 1)}, and the resulting maps
#'   are averaged.
#'
#' @param myplot dynamic map produced by \code{plot_map()} function.
#' @param project an RgeoProfile project, as produced by the function
#'   \code{rgeoprofile_project()}.
#' @param K which value of K to plot.
#' @param source currently validated but not used.
#' @param threshold what proportion of grid cells to plot. Only the
#'   highest-risk \code{threshold} proportion of cells is shown.
#' @param col set of plotting colours.
#' @param opacity opacity of the risk surface (that is not invisible due to
#'   being below threshold).
#' @param smoothing what level of smoothing to apply to the risk surface.
#'   Smoothing is applied using the \code{raster} function \code{disaggregate},
#'   with \code{method = "bilinear"}.
#' @param legend whether or not a legend is plotted.
#' @param iterations the number of posterior draws to sample (without
#'   replacement) and average over. Capped at the number of available draws.
#'
#' @import leaflet
#' @importFrom grDevices grey
#' @importFrom raster ncell
#' @export
#' 
#' @examples
#' \dontrun{
#' p <- rgeoprofile_file("tutorial1_project.rds")
#' plot1 <- plot_map()
#' plot1 <- overlay_risk_map(myplot = plot1, project = p, K = 2, threshold = 0.5)
#' plot1
#' }

overlay_risk_map <- function(myplot,
                             project,
                             K = NULL,
                             source = NULL,
                             threshold = 0.1,
                             col = rev(col_hotcold()),
                             opacity = 0.8,
                             smoothing = 1.0,
                             legend = FALSE,
                             iterations = 50) {
  
  # check inputs
  assert_custom_class(myplot, "leaflet")
  assert_custom_class(project, "rgeoprofile_project")
  if (!is.null(source)) {
    assert_single_pos_int(source, zero_allowed = FALSE)
  }
  assert_bounded(threshold, left = 0, right = 1, inclusive_left = TRUE, inclusive_right = TRUE)
  assert_string(col)
  assert_bounded(opacity, left = 0, right = 1, inclusive_left = TRUE, inclusive_right = TRUE)
  assert_single_pos(smoothing)
  assert_greq(smoothing, 1.0)
  assert_single_logical(legend)
  assert_single_pos_int(iterations)
  
  # get active set and check non-zero
  s <- project$active_set
  if (s == 0) {
    stop("no active parameter set")
  }
  
  # extract sampled parameter values. Rows are treated as posterior draws;
  # columns are presumably one per source - TODO confirm against get_output()
  longs <- get_output(project, name = "source_lon_sampling", type = "raw", K = K)
  lats <- get_output(project, name = "source_lat_sampling", type = "raw", K = K)
  sigmas <- get_output(project, name = "sigma_sampling", type = "raw", K = K)
  expected_popsizes <- get_output(project, name = "expected_popsize_sampling", type = "raw", K = K)
  
  # the spatial prior defines the output grid
  spatial_prior <- project$parameter_sets[[s]]$spatial_prior
  ncells <- raster::ncell(spatial_prior)
  
  # coordinates of every cell centre, in raster cell order, so the final
  # vector of risk values can be passed straight to setValues()
  cell_xy <- raster::xyFromCell(spatial_prior, seq_len(ncells))
  grid_lon <- cell_xy[, 1]
  grid_lat <- cell_xy[, 2]
  
  # choose which posterior draws to use; sample() cannot draw more values
  # than exist without replacement, so cap iterations
  samples <- nrow(longs)
  if (samples == 0) {
    stop("no posterior draws available")
  }
  if (iterations > samples) {
    message(sprintf("only %s posterior draws available, using all of them", samples))
    iterations <- samples
  }
  chosen_iterations <- sample(seq_len(samples), iterations)
  
  # accumulate the risk map over the chosen draws
  risk_sum <- numeric(ncells)
  for (i in chosen_iterations) {
    
    # great-circle distance from every cell to each source (ncells x n_src)
    gc_dist <- mapply(function(x, y) {
      lonlat_to_bearing(x, y, grid_lon, grid_lat)$gc_dist
    }, x = longs[i, ], y = lats[i, ])
    gc_dist <- matrix(gc_dist, nrow = ncells)
    n_src <- ncol(gc_dist)
    
    # sigma laid out so that column j of the distance matrix uses the sigma of
    # source j (passing the raw vector to dnorm would recycle it down rows)
    sigma_row <- rep_len(unlist(sigmas[i, ], use.names = FALSE), n_src)
    sigma_mat <- matrix(sigma_row, nrow = ncells, ncol = n_src, byrow = TRUE)
    
    # bivariate normal density around each source, as a product of two
    # univariate normals
    densities <- dnorm(gc_dist, 0, sigma_mat) * dnorm(0, 0, sigma_mat)
    
    # weight each source by its expected population size and sum over sources
    popsize_row <- unlist(expected_popsizes[i, ], use.names = FALSE)
    hazard_values <- drop(densities %*% popsize_row)
    
    # transform hazard to [0,1] domain and accumulate
    risk_sum <- risk_sum + hazard_values/(hazard_values + 1)
  }
  
  # average over draws; values are in raster cell order (see xyFromCell above)
  risk_raster <- setValues(spatial_prior, risk_sum / iterations)
  
  # apply smoothing
  if (smoothing > 1.0) {
    risk_raster <- disaggregate(risk_raster, smoothing, method = "bilinear")
  }
  
  # apply threshold: keep the top `threshold` proportion of non-missing cells
  # (at least one cell) and make the rest transparent
  risk_values <- values(risk_raster)
  n_keep <- max(1, ceiling(sum(!is.na(risk_values))*threshold))
  cutoff <- sort(risk_values, decreasing = TRUE)[n_keep]
  risk_values[risk_values < cutoff] <- NA
  risk_raster <- setValues(risk_raster, risk_values)
  
  # overlay raster
  myplot <- addRasterImage(myplot, x = risk_raster, colors = col, opacity = opacity)

  # add bounding rect
  myplot <- addRectangles(myplot, xmin(risk_raster), ymin(risk_raster),
                          xmax(risk_raster), ymax(risk_raster),
                          fill = FALSE, weight = 2, color = grey(0.2))
  
  # add legend spanning the range of the plotted probabilities, so legend
  # colours line up with the raster colours
  if (legend) {
    legend_range <- range(risk_values, na.rm = TRUE)
    prob_sequence <- seq(legend_range[1], legend_range[2], length.out = length(col))
    pal <- colorNumeric(palette = col, domain = prob_sequence)
    myplot <- addLegend(myplot, "bottomright", pal = pal, values = prob_sequence,
                        title = "Trial\nProbability", opacity = 1)
  }

  # return plot object
  return(myplot)
}

#------------------------------------------------
#' @title Add sources to dynamic map
#'
#' @description Add sources to dynamic map
#'
#' @param myplot dynamic map produced by \code{plot_map()} function.
#' @param lon,lat longitude and latitude of sources.
#' @param icon_url what image to use for the icon.
#' @param icon_width,icon_height the width and height of the icon.
#' @param icon_anchor_x,icon_anchor_y the coordinates of the "tip" of the icon (relative to
#'   its top left corner, i.e. the top left corner means \code{icon_anchor_x =
#'   0} and \code{icon_anchor_y = 0}), and the icon will be aligned so that this
#'   point is at the marker's geographical location.
#'
#' @import leaflet
#' @export
#' 
#' @examples
#' \dontshow{mysim <- rgeoprofile_file("tutorial1_mysim.rds")}
#' source_locations <- mysim$record$true_source
#' plot1 <- plot_map()
#' plot1 <- overlay_sources(myplot = plot1,
#'                          lon = source_locations$longitude,
#'                          lat = source_locations$latitude)
#' plot1

overlay_sources <- function(myplot,
                            lon,
                            lat,
                            icon_url = NULL,
                            icon_width = 20,
                            icon_height = 20,
                            icon_anchor_x = 10,
                            icon_anchor_y = 10) {
  
  # check inputs
  assert_custom_class(myplot, "leaflet")
  assert_numeric(lon)
  assert_vector(lon)
  assert_numeric(lat)
  assert_vector(lat)
  assert_same_length(lon, lat)
  
  # default icon is fetched from the package's GitHub repository
  if (is.null(icon_url)) {
    icon_url <- "https://github.com/Michael-Stevens-27/silverblaze/raw/master/R_ignore/icons/black_cross.png"
  }
  assert_single_string(icon_url)
  assert_single_pos_int(icon_width)
  assert_single_pos_int(icon_height)
  # an anchor of (0, 0) is the icon's top-left corner, so zero is valid
  assert_single_pos_int(icon_anchor_x, zero_allowed = TRUE)
  assert_single_pos_int(icon_anchor_y, zero_allowed = TRUE)
  
  # create custom icon
  source_icon <- makeIcon(iconUrl = icon_url, iconWidth = icon_width, iconHeight = icon_height,
                          iconAnchorX = icon_anchor_x, iconAnchorY = icon_anchor_y)
  
  # add custom markers
  myplot <- addMarkers(myplot, lng = lon, lat = lat, icon = source_icon)
  
  # return plot object
  return(myplot)
}

#------------------------------------------------
#' @title Add pie charts to dynamic map
#'
#' @description Add pie charts to dynamic map
#'
#' @param myplot dynamic map produced by \code{plot_map()} function.
#' @param project an RgeoProfile project, as produced by the function
#'   \code{rgeoprofile_project()}.
#' @param K which value of K to plot.
#' @param min_size,max_size lower and upper limits on the size of pie charts.
#' @param col segment colours.
#'
#' @import leaflet
#' @import leaflet.minicharts
#' @export
#' 
#' @examples
#' \dontshow{p <- rgeoprofile_file("tutorial1_project.rds")}
#' plot1 <- plot_map()
#' plot1 <- overlay_piecharts(myplot = plot1, project = p, K = 2)
#' plot1

overlay_piecharts <- function(myplot,
                              project,
                              K = NULL,
                              min_size = 10,
                              max_size = 30,
                              col = NULL) {

  # check inputs
  assert_custom_class(myplot, "leaflet")
  assert_custom_class(project, "rgeoprofile_project")
  assert_single_pos(min_size)
  assert_single_pos(max_size)
  assert_greq(max_size, min_size)

  # get output; K is taken from the number of qmatrix columns
  qmatrix <- get_output(project, "qmatrix", K)
  K <- ncol(qmatrix)

  # set default colours from K
  col <- define_default(col, more_colours(K))

  # check correct number of colours
  assert_length(col, K)

  # only sites whose first qmatrix column is non-NA are drawn
  w <- which(!is.na(qmatrix[,1]))
  site_df <- project$data$frame
  lon <- site_df$longitude[w]
  lat <- site_df$latitude[w]
  counts <- site_df$counts[w]

  # scale pie size linearly with count between min_size and max_size. If all
  # counts are zero (or there are none) fall back to min_size instead of
  # dividing by zero
  max_count <- max(c(0, counts), na.rm = TRUE)
  if (max_count > 0) {
    pie_size <- min_size + counts/max_count*(max_size - min_size)
  } else {
    pie_size <- rep(min_size, length(counts))
  }

  # drop = FALSE keeps a one-row matrix when only a single site is drawn, so
  # chartdata always has one row per site
  df <- round(qmatrix[w, , drop = FALSE], digits = 3)

  # overlay pie charts
  myplot <- addMinicharts(myplot, lon, lat,
                          type = "pie",
                          chartdata = df,
                          colorPalette = col,
                          width = pie_size,
                          transitionTime = 20)

  return(myplot)
}

#------------------------------------------------
#' @title Add ring-search geoprofile to dynamic map
#'
#' @description Add ring-search geoprofile to dynamic map.
#'
#' @param myplot dynamic map produced by \code{plot_map()} function.
#' @param project an RgeoProfile project, as produced by the function
#'   \code{rgeoprofile_project()}.
#' @param threshold what proportion of geoprofile to plot.
#' @param col set of plotting colours.
#' @param opacity opacity of geoprofile (that is not invisible due to being
#'   below threshold).
#' @param smoothing what level of smoothing to apply to geoprofile. Smoothing is
#'   applied using the \code{raster} function \code{disaggregate}, with
#'   \code{method = "bilinear"}.
#'
#' @import leaflet
#' @importFrom grDevices grey
#' @export
#' 
#' @examples
#' \dontshow{p <- rgeoprofile_file("tutorial1_project.rds")}
#' plot1 <- plot_map()
#' plot1 <- overlay_ringsearch(myplot = plot1, project = p)
#' show(plot1)

overlay_ringsearch <- function(myplot,
                               project,
                               threshold = 0.1,
                               col = col_hotcold(),
                               opacity = 0.8,
                               smoothing = 1) {

  # --- argument validation ---
  assert_custom_class(myplot, "leaflet")
  assert_custom_class(project, "rgeoprofile_project")
  assert_single_numeric(threshold)
  assert_bounded(threshold, left = 0, right = 1, inclusive_left = TRUE, inclusive_right = TRUE)
  assert_string(col)
  assert_single_numeric(opacity)
  assert_bounded(opacity, left = 0, right = 1, inclusive_left = TRUE, inclusive_right = TRUE)
  assert_single_pos(smoothing)
  assert_greq(smoothing, 1.0)

  # --- locate the ring-search surface of the active parameter set ---
  active_idx <- project$active_set
  if (active_idx == 0) {
    stop("no active parameter set")
  }
  rs_raster <- project$output$single_set[[active_idx]]$all_K$ringsearch
  if (is.null(rs_raster)) {
    stop("no ringsearch output for active set")
  }

  if (smoothing > 1.0) {
    rs_raster <- disaggregate(rs_raster, smoothing, method = "bilinear")
  }

  # hide cells above threshold*100 (values compared on a 0-100 scale)
  rs_values <- values(rs_raster)
  rs_values[rs_values > threshold*100] <- NA
  rs_raster <- setValues(rs_raster, rs_values)

  # --- draw surface and its bounding box; the map is the return value ---
  myplot <- addRasterImage(myplot, x = rs_raster, colors = col, opacity = opacity)
  addRectangles(myplot, xmin(rs_raster), ymin(rs_raster),
                xmax(rs_raster), ymax(rs_raster),
                fill = FALSE, weight = 2, color = grey(0.2))
}
Michael-Stevens-27/silverblaze documentation built on May 28, 2021, 5:47 p.m.